Source code for botiverse.bots.WhizBot.WhizBot
from botiverse.bots.WhizBot.WhizBot_GRU import WhizBot_GRU
from botiverse.bots.WhizBot.WhizBot_BERT import WhizBot_BERT
[docs]class WhizBot:
'''
A class that provides an interface for the WhizBot-BERT and WhizBot-GRU models.
'''
def __init__(self, repr='BERT'):
"""
Initializes WhizBot and sets its representation type.
:param repr: The representation type of the WhizBot model. Either "BERT" or "GRU".
:type repr: str
"""
if repr == 'BERT':
self.bot = WhizBot_BERT()
elif repr == 'GRU':
self.bot = WhizBot_GRU()
else:
raise ValueError('Invalid representation type for WhizBot. Please choose either "BERT" or "GRU".')
[docs] def read_data(self, file_path):
"""
Reads and pre-processes the data, sets up the model based on the data and prepares the train-validation split.
:param file_path: The path to the file that contains the dataset.
:type file_path: str
:returns: None
"""
self.bot.read_data(file_path)
[docs] def train(self, epochs=10, batch_size=32):
"""
Trains the model using the training dataset.
:param epochs: The number of training epochs.
:type epochs: int
:param batch_size: The number of training examples utilized used to make one paramenters updat.
:type batch_size: int
:returns: None
"""
self.bot.train(epochs, batch_size)
[docs] def validation(self, batch_size=32):
"""
Tests the model performance using the validation dataset and calculates the accuracy.
:param batch_size: The number of training examples utilized used to make one paramenters updat.
:type batch_size: int
:returns: None
"""
self.bot.validation(batch_size)
[docs] def infer(self, string):
"""
Performs inference using the model.
:param string: The input string to perform inference on.
:type string: str
:returns: A random response from the response list of the predicted label.
"""
return self.bot.infer(string)
[docs] def save(self, path):
"""
Saves the model parameters to the given path.
:param path: The path where the model parameters will be saved.
:type path: str
:returns: None
"""
self.bot.save(path)
[docs] def load(self, path):
"""
Loads the model parameters from the given path.
:param path: The path from where the model parameters will be loaded.
:type path: str
:returns: None
"""
self.bot.load(path)