"""
This module contains the BERT model architecture.
"""
import torch
import torch.nn as nn
# Embeddings
# 1. Cluster similar words together.
# 2. Preserve different relationships between words, such as semantic, syntactic and linear ones;
# since BERT is bidirectional, it preserves contextual relationships as well.
class Embeddings(nn.Module):
    """
    Embedding layer for BERT.

    Looks up word, position and token-type (segment) embeddings, sums them,
    and applies layer normalization followed by dropout.

    :param config: BERT configuration. The attributes read here are ``vocab_size``,
        ``hidden_size``, ``padding_idx``, ``max_seq``, ``token_types``,
        ``layer_norm_eps`` and ``dropout``.
    :type config: Config
    """

    def __init__(self, config):
        super().__init__()
        # BERT uses 3 types of embeddings: word, position, and token_type (segment type).
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.padding_idx
        )
        self.position_embeddings = nn.Embedding(config.max_seq, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.token_types, config.hidden_size)
        # LayerNorm is used to normalize the sum of the embeddings.
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Embeddings layer.

        :param input_ids: The input token IDs, ``[batch_size, seq_len]``.
        :type input_ids: torch.Tensor
        :param token_type_ids: The token type IDs, ``[batch_size, seq_len]``.
        :type token_type_ids: torch.Tensor
        :return: The generated embeddings, ``[batch_size, seq_len, hidden_size]``.
        :rtype: torch.Tensor
        """
        seq_len = input_ids.size(1)
        # Create the positions on the same device as the inputs. Choosing the device from
        # torch.cuda.is_available() breaks CPU-resident models on CUDA-capable hosts.
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)  # [1, seq_len]

        word_embeddings = self.word_embeddings(input_ids)  # [batch_size, seq_len, hidden_size]
        # The leading dimension of size 1 broadcasts over the batch in the sum below.
        position_embeddings = self.position_embeddings(position_ids)  # [1, seq_len, hidden_size]
        token_type_embeddings = self.token_type_embeddings(token_type_ids)  # [batch_size, seq_len, hidden_size]

        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        # Normalize by subtracting the mean and dividing by the standard deviation calculated
        # across the feature dimension, then multiply by a learned gain parameter and add a
        # learned bias parameter.
        embeddings = self.layer_norm(embeddings)  # [batch_size, seq_len, hidden_size]
        return self.dropout(embeddings)
# Encoder layer
class EncoderLayer(nn.Module):
    """
    Encoder layer for BERT.

    A self-attention sub-layer followed by a position-wise feed-forward sub-layer.
    Each sub-layer's output goes through dropout, is added back to the sub-layer's
    input (residual connection) and is then layer-normalized.

    :param config: BERT configuration.
    :type config: Config
    """

    def __init__(self, config):
        super().__init__()
        # Self-attention sub-layer with its own normalization and dropout.
        self.self_attention = MultiHeadAttention(config)
        self.self_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.self_dropout = nn.Dropout(config.dropout)
        # Feed-forward sub-layer with its own normalization and dropout.
        self.position_wise_feed_forward = PositionWiseFeedForward(config)
        self.ffn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.ffn_dropout = nn.Dropout(config.dropout)

    def forward(self, input, attention_mask):
        """
        Forward pass of the EncoderLayer.

        :param input: The input tensor, ``[batch_size, seq_len, hidden_size]``.
        :type input: torch.Tensor
        :param attention_mask: The attention mask, passed through to the attention sub-layer.
        :type attention_mask: torch.Tensor
        :return: The layer output ``[batch_size, seq_len, hidden_size]`` and the attention
            tensor returned by the attention sub-layer ``[batch_size, heads, seq_len, seq_len]``.
        :rtype: torch.Tensor, torch.Tensor
        """
        # Self-attention: the same tensor serves as query, key and value.
        attended, attention = self.self_attention(input, input, input, attention_mask)
        # Residual connection + layer norm around the attention sub-layer.
        hidden = self.self_layer_norm(input + self.self_dropout(attended))

        # Residual connection + layer norm around the feed-forward sub-layer.
        transformed = self.position_wise_feed_forward(hidden)
        result = self.ffn_layer_norm(hidden + self.ffn_dropout(transformed))
        return result, attention
# Multi-head attention
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention layer for BERT.

    Projects query, key and value, runs scaled dot-product attention independently
    per head, and projects the concatenated heads back to ``hidden_size``.

    :param config: BERT configuration. ``hidden_size`` and ``dropout`` are read at
        construction time; ``heads`` is read from the stored config on every forward pass.
    :type config: Config
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.w_q = nn.Linear(config.hidden_size, config.hidden_size)
        self.w_k = nn.Linear(config.hidden_size, config.hidden_size)
        self.w_v = nn.Linear(config.hidden_size, config.hidden_size)
        self.w_o = nn.Linear(config.hidden_size, config.hidden_size)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, query, key, value, attention_mask):
        """
        Forward pass of the MultiHeadAttention.

        :param query: The query tensor, ``[batch_size, seq_len_q, hidden_size]``.
        :type query: torch.Tensor
        :param key: The key tensor, ``[batch_size, seq_len_k, hidden_size]``.
        :type key: torch.Tensor
        :param value: The value tensor, ``[batch_size, seq_len_k, hidden_size]``.
        :type value: torch.Tensor
        :param attention_mask: The attention mask, ``[batch_size, seq_len_q, seq_len_k]``.
            Positions whose mask value equals 0 are excluded from attention.
        :type attention_mask: torch.Tensor
        :return: The output context ``[batch_size, seq_len_q, hidden_size]`` and the
            attention probabilities after dropout ``[batch_size, heads, seq_len_q, seq_len_k]``.
        :rtype: torch.Tensor, torch.Tensor
        """
        batch_size, q_len, hidden_size = query.size()
        heads = self.config.heads
        head_dim = hidden_size // heads

        # Split the projections into heads. Key and value use their own sequence length so
        # the layer also works when it differs from the query's.
        query = self.w_q(query).view(batch_size, q_len, heads, head_dim).transpose(1, 2)  # [batch_size, heads, seq_len_q, head_dim]
        key = self.w_k(key).view(batch_size, key.size(1), heads, head_dim).transpose(1, 2)  # [batch_size, heads, seq_len_k, head_dim]
        value = self.w_v(value).view(batch_size, value.size(1), heads, head_dim).transpose(1, 2)  # [batch_size, heads, seq_len_k, head_dim]

        # Scaled dot-product attention. A Python float scale avoids building a tensor per call.
        scores = torch.matmul(query, key.transpose(-1, -2)) / (head_dim ** 0.5)  # [batch_size, heads, seq_len_q, seq_len_k]

        # A singleton head dimension lets the mask broadcast over heads instead of being copied.
        masked_out = attention_mask.unsqueeze(1) == 0  # [batch_size, 1, seq_len_q, seq_len_k]
        # Fill with the dtype's own minimum: a hard-coded -1e9 is not representable in float16.
        scores = scores.masked_fill(masked_out, torch.finfo(scores.dtype).min)

        attention = self.softmax(scores)  # [batch_size, heads, seq_len_q, seq_len_k]
        attention = self.dropout(attention)

        context = torch.matmul(attention, value)  # [batch_size, heads, seq_len_q, head_dim]
        # Merge the heads back into a single hidden dimension.
        context = context.transpose(1, 2).contiguous().view(batch_size, q_len, hidden_size)
        output = self.w_o(context)  # [batch_size, seq_len_q, hidden_size]
        return output, attention
# Position-wise feed-forward network
class PositionWiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network layer for BERT.

    Expands the last dimension from ``hidden_size`` to ``ff_size``, applies GELU,
    and projects back to ``hidden_size``.

    :param config: BERT configuration (``hidden_size`` and ``ff_size`` are read).
    :type config: Config
    """

    def __init__(self, config):
        super().__init__()
        model_dim, inner_dim = config.hidden_size, config.ff_size
        self.linear1 = nn.Linear(model_dim, inner_dim)
        self.linear2 = nn.Linear(inner_dim, model_dim)
        self.gelu = nn.GELU()

    def forward(self, input):
        """
        Forward pass of the PositionWiseFeedForward layer.

        :param input: The input tensor, ``[batch_size, seq_len, hidden_size]``.
        :type input: torch.Tensor
        :return: The output tensor, ``[batch_size, seq_len, hidden_size]``.
        :rtype: torch.Tensor
        """
        # expand -> GELU -> project back, applied to the last dimension only
        expanded = self.linear1(input)  # [batch_size, seq_len, ff_size]
        return self.linear2(self.gelu(expanded))  # [batch_size, seq_len, hidden_size]
# Bert
# 1. Puts it all together.
class Bert(nn.Module):
    """
    BERT model implementation.

    Chains the Embeddings layer, a stack of EncoderLayers, and a linear + tanh
    pooling head applied to the first position of the sequence.

    :param config: BERT configuration. ``encoder_layers`` and ``hidden_size`` are read
        here; the configuration is also passed on to the sub-layers.
    :type config: Config
    """

    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        self.encoder = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
        # Pooling head: linear + tanh over the first token's hidden state.
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.tanh = nn.Tanh()

    def forward(self, input_ids, token_type_ids, attention_mask, return_dict=False):
        """
        Forward pass of the Bert model.

        :param input_ids: The input token IDs, ``[batch_size, seq_len]``.
        :type input_ids: torch.Tensor
        :param token_type_ids: The token type IDs, ``[batch_size, seq_len]``.
        :type token_type_ids: torch.Tensor
        :param attention_mask: The attention mask, ``[batch_size, seq_len]``.
        :type attention_mask: torch.Tensor
        :param return_dict: Accepted for call-site compatibility but not read by this
            method; a tuple is always returned. Defaults to False.
        :type return_dict: bool
        :return: The sequence output ``[batch_size, seq_len, hidden_size]`` and the
            pooled output ``[batch_size, hidden_size]``.
        :rtype: torch.Tensor, torch.Tensor
        """
        # Embedding
        hidden = self.embeddings(input_ids, token_type_ids)  # [batch_size, seq_len, hidden_size]

        # Encoder
        # Every query position shares the same key mask, so a broadcast view is enough;
        # no need to materialize seq_len copies of it.
        seq_len = hidden.size(1)
        pair_mask = attention_mask.unsqueeze(1).expand(-1, seq_len, -1)  # [batch_size, seq_len, seq_len]
        for encoder_layer in self.encoder:
            # The per-layer attention tensors are not part of this model's output.
            hidden, _ = encoder_layer(hidden, pair_mask)

        # Sequence and pooled outputs
        sequence_output = hidden  # [batch_size, seq_len, hidden_size]
        pooled_output = self.tanh(self.linear(sequence_output[:, 0]))  # [batch_size, hidden_size]
        return sequence_output, pooled_output