"""
This module provides the TRIPPY run function, which trains and evaluates the model.
"""
import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from botiverse.models.TRIPPY.data import prepare_data, Dataset
from botiverse.models.TRIPPY.train import train
from botiverse.models.TRIPPY.evaluate import eval
def run(model, domains, slot_list, label_maps, train_path, dev_path, test_path, device, non_referable_slots, non_referable_pairs, model_path, TRIPPY_config):
    """
    Train and evaluate the TRIPPY model.

    The model is trained for ``TRIPPY_config.epochs`` epochs on the train set.
    If ``dev_path`` is given, the model is evaluated on the dev set after every
    epoch; the weights with the highest joint goal accuracy are saved to
    ``model_path`` and reloaded once training finishes. Without a dev set the
    latest weights are saved to ``model_path`` after every epoch. If
    ``test_path`` is given, the resulting model is evaluated on the test set.

    :param model: The TRIPPY model.
    :type model: TRIPPY
    :param domains: The domains to consider in the dataset.
    :type domains: list
    :param slot_list: The list of slots.
    :type slot_list: list
    :param label_maps: The mapping of slot values to their variants.
    :type label_maps: dict
    :param train_path: The path to the training dataset in JSON format.
    :type train_path: str
    :param dev_path: The path to the development dataset in JSON format, or None to skip dev evaluation.
    :type dev_path: str
    :param test_path: The path to the testing dataset in JSON format, or None to skip test evaluation.
    :type test_path: str
    :param device: The device to train and evaluate the model on.
    :type device: torch.device
    :param non_referable_slots: The slots that are not referable.
    :type non_referable_slots: list
    :param non_referable_pairs: The pairs of slots that are not referable.
    :type non_referable_pairs: list
    :param model_path: The path to save the best model.
    :type model_path: str
    :param TRIPPY_config: The configuration for TRIPPY.
    :type TRIPPY_config: TRIPPYConfig
    """
    n_slots = len(slot_list)

    def load_split(name, path, batch_size, shuffle=False):
        """Preprocess one data split and return (raw_data, prepared_data, data_loader)."""
        print(f'Preprocessing {name} set...')
        raw_data, data = prepare_data(path, slot_list, label_maps, TRIPPY_config.tokenizer, TRIPPY_config.max_len,
                                      domains, non_referable_slots, non_referable_pairs, TRIPPY_config.multiwoz)
        dataset = Dataset(data, n_slots, TRIPPY_config.oper2id, slot_list)
        # sampler=None makes DataLoader iterate sequentially, i.e. no shuffling.
        sampler = torch.utils.data.RandomSampler(dataset) if shuffle else None
        loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=batch_size)
        return raw_data, data, loader

    _, _, train_data_loader = load_split('train', train_path, TRIPPY_config.train_batch_size, shuffle=True)
    # NOTE(review): eval() consumes the raw/prepared data directly, so the dev/test
    # DataLoaders below are currently unused; kept in case Dataset construction matters.
    if dev_path is not None:
        dev_raw_data, dev_data, _ = load_split('dev', dev_path, TRIPPY_config.dev_batch_size)
    if test_path is not None:
        test_raw_data, test_data, _ = load_split('test', test_path, TRIPPY_config.test_batch_size)

    # One scheduler step per training batch, over all epochs.
    num_train_steps = len(train_data_loader) * TRIPPY_config.epochs
    num_warmup_steps = int(num_train_steps * TRIPPY_config.warmup_proportion)
    optimizer = AdamW(_grouped_parameters(model, TRIPPY_config.weight_decay),
                      lr=TRIPPY_config.lr, eps=TRIPPY_config.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps
    )

    best_joint = -1
    best_saved = False  # whether a dev-selected checkpoint was written during this run
    for epoch in range(TRIPPY_config.epochs):
        print(f'\nEpoch: {epoch} ---------------------------------------------------------------')
        print('Training the model...')
        train(train_data_loader, model, optimizer, device, scheduler, n_slots, TRIPPY_config.ignore_idx, TRIPPY_config.oper2id)
        if dev_path is None:
            # No dev set means no model selection: always keep the latest weights.
            torch.save(model.state_dict(), model_path)
            continue
        print('Evaluating the model on dev set...')
        joint_goal_acc, per_slot_acc, macro_f1_score, all_f1_score = eval(
            dev_raw_data, dev_data, model, device, n_slots, slot_list, label_maps,
            TRIPPY_config.oper2id, TRIPPY_config.multiwoz)
        _print_metrics(joint_goal_acc, per_slot_acc, macro_f1_score, all_f1_score)
        if joint_goal_acc > best_joint:
            torch.save(model.state_dict(), model_path)
            best_joint = joint_goal_acc
            best_saved = True

    if dev_path is not None:
        if best_saved:
            print('Loading best model on dev set...')
            # map_location keeps the reload on the requested device regardless of where it was saved.
            model.load_state_dict(torch.load(model_path, map_location=device))
        else:
            # No epoch ran, or the dev accuracy never compared above the initial value
            # (e.g. NaN); a file at model_path would be missing or stale, so don't load it.
            print('No best model was recorded on dev set; keeping the current weights.')

    if test_path is not None:
        print('Evaluating the model on test set...')
        joint_goal_acc, per_slot_acc, macro_f1_score, all_f1_score = eval(
            test_raw_data, test_data, model, device, n_slots, slot_list, label_maps,
            TRIPPY_config.oper2id, TRIPPY_config.multiwoz)
        _print_metrics(joint_goal_acc, per_slot_acc, macro_f1_score, all_f1_score)


def _grouped_parameters(model, weight_decay):
    """Split model parameters into a weight-decayed group and a no-decay group (biases, LayerNorm)."""
    no_decay = ("bias", "LayerNorm.bias", "LayerNorm.weight")
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]


def _print_metrics(joint_goal_acc, per_slot_acc, macro_f1_score, all_f1_score):
    """Print the evaluation metrics returned by eval()."""
    print(f'Joint Goal Acc: {joint_goal_acc}')
    print(f'Per Slot Acc: {per_slot_acc}')
    print(f'Macro F1 Score: {macro_f1_score}')
    print(f'All f1 score = {all_f1_score}')