Source code for botiverse.Theorizer.model.utils

[docs]def get_overlap_position(para_ids, ans_ids, ans_prefix_ids): """ Get the position (start and end indices) of the overlapping region between a paragraph and an answer after the answer prefix. Args: para_ids (list): The paragraph token IDs. ans_ids (list): The answer token IDs. ans_prefix_ids (list): The prefix token IDs of the answer. Returns: tuple: A tuple representing the start and end indices of the overlapping region. """ # Find the first index where the paragraph and answer prefix differ for i, (para_id, ans_prefix_id) in enumerate(zip(para_ids, ans_prefix_ids)): if para_id != ans_prefix_id: first_diff_index = i break else: first_diff_index = min(len(ans_prefix_ids), len(para_ids)) # Calculate the end index of the overlapping region overlap_end_index = min(first_diff_index + len(ans_ids), len(para_ids)) return (first_diff_index, overlap_end_index)
[docs]def pad_dataset(dataset, padding=0): """Pad the dataset. This could be optimized by defining a Dataset class and padd only batches but this is simpler.""" max_l = max(len(x) for x in dataset["input_ids"]) for name in MODEL_INPUTS: dataset[name] = [ x + [padding if name != "lm_labels" else -100] * (max_l - len(x)) for x in dataset[name] ] return dataset