Source code for botiverse.Theorizer.generate

import torch
from botiverse.Theorizer.model.finetuned_model import MyGPT2LMHeadModel
from botiverse.Theorizer.model.dataloader import SPECIAL_TOKENS_DICT
from botiverse.Theorizer.squad.sample_data import select_with_default_sampel_probs
from transformers import GPT2Tokenizer
import os

current_file_dir = os.path.dirname(os.path.abspath(__file__))

def __prepare(context):
    sampled_infos = select_with_default_sampel_probs(context)
    instances = []
    for info in sampled_infos["selected_infos"]:
        for style in info["styles"]:
            for clue in info["clues"]:
                instances.append(
                    {
                        "paragraph": sampled_infos["context"],
                        "clue": clue.clue_text,
                        "answer": info["answer"]["answer_text"],
                        "style": style,
                    }
                )
    return instances


[docs]def generate(context, max_length=50): # Load the fine-tuned model and tokenizer model_path_or_name = os.path.join(current_file_dir,"model/pretrained-model") model = MyGPT2LMHeadModel.from_pretrained( model_path_or_name, ignore_mismatched_sizes=True ) tokenizer = GPT2Tokenizer.from_pretrained( model_path_or_name, **SPECIAL_TOKENS_DICT ) instances = __prepare(context) # print(instances) qa_dict = {} qa_dict["context"] = context qa_dict["qa"] = set() for inst in instances: paragraph = inst["paragraph"] clue = inst["clue"] answer = inst["answer"] style = inst["style"] input_sequence = ( "<sos> " + paragraph + " <clue> " + clue + " <answer> " + answer + " <style> " + style + " <question> " + style ) # Tokenize the input sequence input_ids = tokenizer.encode(input_sequence, return_tensors="pt") with torch.no_grad(): # Generate the question generated = model.generate( input_ids=input_ids, max_length=max_length, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, temperature=0.7, ) question_start_index = ( input_ids[0].tolist().index( tokenizer.convert_tokens_to_ids("<question>")) ) # Slice the generated tensor to exclude the input sequence generated_question = generated[0, question_start_index + 1:] # Decode the generated question question = tokenizer.decode( generated_question, skip_special_tokens=True) qa_dict["qa"].add((question, answer)) qa_dict["qa"] = list(qa_dict["qa"]) return qa_dict
if __name__ == "__main__": context = "Bob is eating a delicious cake in Vancouver." qa_dict = generate(context) import json print(json.dumps(qa_dict,indent=4))