Source code for botiverse.bots.ConverseBot.ConverseBot

import torch
from transformers import AutoModelForSeq2SeqLM
from tqdm.auto import tqdm
import torch.optim as optim
from botiverse.preprocessors.Special.ConverseBot_Preprocessor.ConverseBot_Preprocessor import ConverseBot_Preprocessor
from botiverse.models.T5Model.T5Model import T5Model

class ConverseBot:
    """
    An interface for the ConverseBot model, a conversational model based on the Flan-T5 model.

    Typical usage is ``ConverseBot()`` followed by :meth:`infer`. For additional
    fine-tuning call :meth:`read_data`, then :meth:`train` and, optionally,
    :meth:`validation`.
    """

    # Hugging Face Hub identifier of the fine-tuned checkpoint used by both code paths.
    _PRETRAINED_NAME = "MohamedSaad/T5_ConverseBot"

    def __init__(self, from_scratch=False):
        """
        Initializes a ConverseBot instance and loads the backend fine-tuned parameters.

        No dataset is loaded here; call :meth:`read_data` before :meth:`train` or
        :meth:`validation`.

        :param from_scratch: Boolean flag to indicate whether to load the model version that is
            implemented from scratch (``T5Model``) instead of the Hugging Face implementation
            (recommended to be False).
        :type from_scratch: bool, optional

        :returns: None
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.from_scratch = from_scratch
        if self.from_scratch:
            # create a model instance
            self.model = T5Model()
            # load the backend fine-tuned parameters into the from-scratch implementation
            pretrained = AutoModelForSeq2SeqLM.from_pretrained(self._PRETRAINED_NAME)
            self.model.load_state_dict(pretrained.state_dict())
        else:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self._PRETRAINED_NAME)
        # move the model to the GPU if available
        self.model.to(self.device)
        # load the preprocessor
        self.preprocessor = ConverseBot_Preprocessor()

    def read_data(self, file_path, train_frac=0.99, random_state=0):
        """
        Reads and pre-processes the data and prepares the train-validation split.

        The processed frame is stored in ``self.data`` and the two splits in
        ``self.train_data`` and ``self.validation_data`` (both re-indexed from 0).

        :param file_path: The path to the file that contains the dataset.
        :type file_path: str

        :param train_frac: Fraction of the rows sampled into the training split; the
            remaining rows form the validation split.
        :type train_frac: float, optional

        :param random_state: Seed used when sampling the training split.
        :type random_state: int, optional

        :returns: None
        """
        # read data
        self.preprocessor = ConverseBot_Preprocessor(file_path=file_path)
        # process data
        self.data = self.preprocessor.process()
        # train validation split: sample the training rows, everything else is validation
        train_split = self.data.sample(frac=train_frac, random_state=random_state)
        validation_split = self.data.drop(train_split.index)
        self.train_data = train_split.reset_index(drop=True)
        self.validation_data = validation_split.reset_index(drop=True)

    def _collate(self, frame, start, batch_size):
        """Concatenate the tensors of rows ``[start, start + batch_size)`` and move them to the device.

        Returns a ``(input_ids, attention_mask, labels)`` tuple.
        """
        rows = slice(start, start + batch_size)
        return tuple(
            torch.concat(frame[column].iloc[rows].tolist()).to(self.device)
            for column in ("text_input_ids", "text_attention_mask", "target")
        )

    def _forward_loss(self, input_ids, attention_mask, labels):
        """Run one forward pass and return the loss tensor for either model implementation."""
        if self.from_scratch:
            # The from-scratch model is called exactly as the original training loop did:
            # first row of the attention mask, labels as decoder inputs, loss at index 1.
            # NOTE(review): T5Model's signature is not visible here - confirm this convention.
            return self.model(
                input_ids=input_ids,
                attention_mask=attention_mask[0],
                decoder_input_ids=labels,
            )[1]
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        ).loss

    def train(self, epochs=1, batch_size=1):
        """
        Trains the model on the input dataset.

        Requires :meth:`read_data` to have been called first. After every epoch the
        loss of the last batch of that epoch is printed.

        :param epochs: Number of epochs to train for.
        :type epochs: int, optional

        :param batch_size: The size of the training batches.
        :type batch_size: int, optional

        :returns: None
        """
        self.model.train()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.00005)
        for epoch in range(epochs):
            loss = None
            for start in tqdm(range(0, len(self.train_data), batch_size)):
                self.model.zero_grad()
                # prepare the training batch
                input_ids, attention_mask, labels = self._collate(self.train_data, start, batch_size)
                loss = self._forward_loss(input_ids, attention_mask, labels)
                loss.backward()
                self.optimizer.step()
            # with an empty training set there is no batch, hence no loss to report
            if loss is not None:
                print(f"Epoch: {epoch} Loss: {loss.item()}")

    def validation(self, batch_size=1):
        """
        Validates the model on the validation dataset and prints the average loss.

        Requires :meth:`read_data` to have been called first.

        :param batch_size: The size of the validation batches.
        :type batch_size: int, optional

        :returns: None
        """
        total_samples = 0
        total_loss = 0.0
        self.model.eval()
        with torch.no_grad():
            for start in tqdm(range(0, len(self.validation_data), batch_size)):
                # prepare the validation batch
                input_ids, attention_mask, labels = self._collate(self.validation_data, start, batch_size)
                batch_samples = labels.size(0)
                # Weight each batch loss by its sample count so that dividing by the
                # total sample count does not shrink the result when batch_size > 1.
                # Assumes the model reports a batch-mean loss (true for Hugging Face
                # seq2seq models) - TODO confirm for T5Model.
                total_loss += self._forward_loss(input_ids, attention_mask, labels).item() * batch_samples
                total_samples += batch_samples
        if total_samples == 0:
            # a tiny dataset can leave the validation split empty; avoid dividing by zero
            print('Validation set is empty; nothing to evaluate.')
            return
        print('Validation Loss: ', total_loss / total_samples)

    def infer(self, string, temperature=1):
        """
        Inference on the model using the input string.

        :param string: The string to provide for inference.
        :type string: str

        :param temperature: The temperature of the softmax function, the higher its value the flatter
            the probability distribution of the next token will be. It is forwarded unchanged to the
            model's ``generate`` method.
        :type temperature: float, optional

        :returns: Inference result from the model.
        :rtype: str
        """
        self.model.eval()
        token_obj = self.preprocessor.process_string(string)
        input_ids = token_obj['input_ids'].to(self.device)
        attention_mask = token_obj['attention_mask'].to(self.device)
        # generation never needs gradients; make that explicit for both implementations
        with torch.no_grad():
            if self.from_scratch:
                output_tokens = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_length=250,
                    temperature=temperature,
                )
                return self.preprocessor.decode_tokens(output_tokens)
            output_tokens = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=250,
                num_beams=5,
                early_stopping=True,
                temperature=temperature,
                repetition_penalty=1.5,
            )
        # the Hugging Face path decodes only the first returned sequence
        return self.preprocessor.decode_tokens(output_tokens[0])

    # save a model locally
    def save(self, path):
        """
        Save the model locally to the provided path.

        Only the model's ``state_dict`` is written, not the model object itself.

        :param path: The path where to save the model.
        :type path: str

        :returns: None
        """
        torch.save(self.model.state_dict(), path)

    # load a model from a local path
    def load(self, path):
        """
        Load the model from the provided path.

        The file must contain a ``state_dict`` such as the one written by :meth:`save`.

        :param path: The path where to load the model from.
        :type path: str

        :returns: None
        """
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
        state_dict = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state_dict)