"""
This module defines a GPT2 model with a custom head for language modeling.
It is adapted from the original Hugging Face `GPT2LMHeadModel` implementation. Some methods (such as `_reorder_cache`) are taken as-is from that implementation because the model needs them to work; they override methods of the `GPT2PreTrainedModel` class.
I implemented the `forward` method myself, keeping the same set of parameters as the original implementation, since those are the arguments used during generation.
"""
from typing import *
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import GPT2PreTrainedModel, GPT2Model
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
class MyGPT2LMHeadModel(GPT2PreTrainedModel):
    """GPT-2 language model with a small MLP inserted before the LM head.

    The last hidden states of the GPT-2 backbone pass through three bias-free
    linear layers (n_embd -> 1024 -> 512 -> n_embd), each followed by GELU.
    A bias-free ``lm_head`` then projects them to vocabulary logits.
    """

    # Parameters specific to this model (the MLP) and the LM head may be
    # missing from a pretrained GPT-2 checkpoint; don't warn about them.
    _keys_to_ignore_on_load_missing = [
        r"linear1.weight",
        r"linear2.weight",
        r"linear3.weight",
        r"lm_head.weight",
    ]
    # Attention buffers that may appear in checkpoints but are not
    # parameters this model expects.
    _keys_to_ignore_on_load_unexpected = [
        r"h\.\d+\.attn\.masked_bias",
        r"h\.\d+\.attn\.bias",
    ]

    def __init__(self, config):
        """Build the GPT-2 backbone, the MLP and the LM head from ``config``."""
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.linear1 = nn.Linear(config.n_embd, 1024, bias=False)
        self.linear2 = nn.Linear(1024, 512, bias=False)
        self.linear3 = nn.Linear(512, config.n_embd, bias=False)
        # linear3 produces n_embd features, so the head must consume n_embd.
        # (A hard-coded 2048 only matched configs with n_embd == 2048 and
        # otherwise failed in forward unless weight tying replaced the weight.)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight initialization and final processing (e.g. weight tying).
        self.post_init()

    def get_output_embeddings(self):
        """Return the output projection (``lm_head``)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection (``lm_head``) with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        Run the GPT-2 backbone, the MLP and the LM head, optionally computing the LM loss.

        All arguments except ``labels`` are forwarded unchanged to ``GPT2Model``.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Label values set to `-100` are ignored (masked); the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.
        return_dict (`bool`, *optional*):
            Defaults to `config.use_return_dict`.

        Returns:
            If `return_dict`: a `CausalLMOutputWithCrossAttentions` whose `loss` is `None` when no labels are
            given. Otherwise a tuple `(loss, logits, *transformer_outputs[1:])`, with `loss` omitted when no
            labels are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # MLP: n_embd -> 1024 -> 512 -> n_embd, GELU after every layer.
        hidden_states = nn.functional.gelu(self.linear1(hidden_states))
        hidden_states = nn.functional.gelu(self.linear2(hidden_states))
        hidden_states = nn.functional.gelu(self.linear3(hidden_states))

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Move labels to the logits' device to enable model parallelism.
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens; CrossEntropyLoss ignores label -100 by default.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Every cached tensor is re-indexed along dimension 0 with `beam_idx`, which is
        moved to that tensor's device first.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )