Source code for botiverse.bots.WhizBot.WhizBot_GRU
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
from botiverse.models.GRUClassifier.GRUClassifier import GRUTextClassifier
from botiverse.preprocessors.Special.WhizBot_GRU_Preprocessor.WhizBot_GRU_Preprocessor import WhizBot_GRU_Preprocessor
import random
[docs]class WhizBot_GRU:
'''An interface for the WhizBot_GRU model which is based on simple GRU model'''
def __init__(self):
"""
Initializes WhizBot_GRU, and will prepare the GPU device based on CUDA availability.
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
[docs] def read_data(self, file_path):
"""
Reads and pre-processes the data, sets up the model based on the data and prepares the train-validation split.
:param file_path: The path to the file that contains the dataset.
:type file_path: str
:returns: None
"""
# read data
self.preprocessor = WhizBot_GRU_Preprocessor(file_path)
# process data
self.data = self.preprocessor.process()
# prepare model
vocab_size = len(self.preprocessor.tokenizer.get_vocab())
num_labels = len(self.preprocessor.label_dict)
self.model = GRUTextClassifier(vocab_size, 300, num_labels).to(self.device)
self.criterion = nn.CrossEntropyLoss()
self.optimizer = optim.SGD(self.model.parameters(), lr=0.05)
# train validation split
self.train_data = self.data.sample(frac=0.8, random_state=0)
self.validation_data = self.data.drop(self.train_data.index)
self.train_data = self.train_data.reset_index(drop=True)
self.validation_data = self.validation_data.reset_index(drop=True)
[docs] def train(self, epochs=50, batch_size=32):
"""
Trains the model using the training dataset.
:param epochs: The number of training epochs.
:type epochs: int
:param batch_size: The number of training examples utilized used to make one paramenters updat.
:type batch_size: int
:returns: None
"""
self.model.train()
for epoch in range(epochs):
for i in tqdm(range(0, len(self.train_data), batch_size)):
self.model.zero_grad()
# get the batches
batch_texts = torch.stack(self.train_data['text'][i:i+batch_size].tolist()).to(self.device)
batch_labels = torch.cat(self.train_data['label'][i:i+batch_size].tolist()).to(self.device)
output = self.model(batch_texts)
loss = self.criterion(output, batch_labels)
# backpropagate
loss.backward()
self.optimizer.step()
print("Epoch: " + str(epoch) + " Loss: " + str(loss.item()))
[docs] def validation(self, batch_size=32):
"""
Tests the model performance using the validation dataset and calculates the accuracy.
:param batch_size: The number of training examples utilized used to make one paramenters updat.
:type batch_size: int
:returns: None
"""
correct = 0
total = 0
self.model.eval()
with torch.no_grad():
for i in tqdm(range(0, len(self.validation_data), batch_size)):
# get the batches
batch_texts = torch.stack(self.validation_data['text'][i:i+batch_size].tolist()).to(self.device)
batch_labels = torch.cat(self.validation_data['label'][i:i+batch_size].tolist()).to(self.device)
outputs = self.model(batch_texts)
_, predicted = torch.max(outputs.data, 1)
# calculate accuracy
total += batch_labels.size(0)
correct += (predicted == batch_labels).sum().item()
print('Accuracy: %d %%' % (100 * correct / total))
[docs] def infer(self, string):
"""
Performs inference using the model.
:param string: The input string to perform inference on.
:type string: str
:returns: A random response from the response list of the predicted label.
"""
self.model.eval()
with torch.no_grad():
string = self.preprocessor.process_string(string)
string = string.unsqueeze(0).to(self.device)
output = self.model(string)
_, predicted = torch.max(output.data, 1)
# get the label
for key, value in self.preprocessor.label_dict.items():
if value == predicted.item():
label = key
break
# return a random responce of the label
return random.choice(self.preprocessor.responces[label])
[docs] def save(self, path):
"""
Saves the model parameters to the given path.
:param path: The path where the model parameters will be saved.
:type path: str
:returns: None
"""
torch.save(self.model.state_dict(), path)
[docs] def load(self, path):
"""
Loads the model parameters from the given path.
:param path: The path from where the model parameters will be loaded.
:type path: str
:returns: None
"""
self.model.load_state_dict(torch.load(path))