Source code for botiverse.models.TRIPPY.infer

"""
This Module has the inference functions for TRIPPY.
"""


import torch

from botiverse.models.TRIPPY.data import create_inputs
from botiverse.models.TRIPPY.utils import create_span_output


[docs]def infer(model, slot_list, current_state, history, sys_utter, user_utter, inform_mem, device, oper2id, tokenizer, max_len): """ Infer the dialogue state using the TRIPPY model. :param model: The TRIPPY model for inference. :type model: TRIPPY :param slot_list: The list of slots. :type slot_list: list :param current_state: The current dialogue state. :type current_state: dict :param history: The dialogue history. :type history: list :param sys_utter: The system's utterance. :type sys_utter: str :param user_utter: The user's utterance. :type user_utter: str :param inform_mem: The inform memory. :type inform_mem: dict :param device: The device to run the inference on. :type device: torch.device :param oper2id: The mapping of operations to IDs. :type oper2id: dict :param tokenizer: The tokenizer to tokenize the input. :type tokenizer: transformers.PreTrainedTokenizer :param max_len: The maximum length of the input sequence. :type max_len: int :return: The predicted dialogue state. :rtype: dict """ model.eval() # turn data to inputs input, ids, mask, token_type_ids, tok_input_offsets, input_tokens, padding_len = create_inputs(history, user_utter, sys_utter, tokenizer, max_len) # print(input, ids, mask, token_type_ids, tok_input_offsets, input_tokens, padding_len) with torch.no_grad(): n_slots = len(slot_list) ids = torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device) mask = torch.tensor(mask, dtype=torch.long).unsqueeze(0).to(device) token_type_ids = torch.tensor(token_type_ids, dtype=torch.long).unsqueeze(0).to(device) inform_aux_features = torch.zeros((1, n_slots)).to(device) ds_aux_features = torch.zeros((1, n_slots)).to(device) input_tokens = ' '.join(input_tokens) padding_len = padding_len for slot_idx, slot in enumerate(slot_list): if slot in inform_mem: inform_aux_features[0, slot_idx] = 1 if slot in current_state: ds_aux_features[0, slot_idx] = 1 # print(slot_list) # print(inform_aux_features) # print(inform_mem) # get model outputs slots_start_logits, slots_end_logits, slots_oper_logits, slots_refer_logits = model(ids=ids, mask=mask, token_type_ids=token_type_ids, inform_aux_features=inform_aux_features, ds_aux_features=ds_aux_features) # update the predicted state of each slot pred_state = current_state.copy() for slot_idx, slot in enumerate(slot_list): # get the predicted operation pred_oper = slots_oper_logits[slot_idx][0].argmax(dim=-1).item() # print(slot, torch.softmax(slots_oper_logits[slot_idx][0], dim=-1)) # update the slot based on the operation if pred_oper == oper2id['carryover']: # carryover continue elif pred_oper == oper2id['dontcare']: # dontcare pred_state[slot] = 'dontcare' elif pred_oper == oper2id['update']: # update pred_state[slot] = create_span_output(slots_start_logits[slot_idx][0].cpu().detach().numpy(), slots_end_logits[slot_idx][0].cpu().detach().numpy(), padding_len, input_tokens) elif pred_oper == oper2id['refer']: # refer refered_slot = slots_refer_logits[slot_idx][0].argmax(dim=-1).item() if refered_slot != n_slots and slot_list[refered_slot] in current_state: pred_state[slot] = current_state[slot_list[refered_slot]] elif pred_oper == oper2id['yes']: # yes pred_state[slot] = 'yes' elif pred_oper == oper2id['no']: # no pred_state[slot] = 'no' elif pred_oper == oper2id['inform']: # inform if slot in inform_mem: pred_state[slot] = '§§' + inform_mem[slot][0] return pred_state