import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEmbedding(torch.nn.Module):

    def __init__(self, d_model, max_len=128):
        super().__init__()

        # Compute the positional encodings once up front; they are fixed, not learned.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False

        for pos in range(max_len):
            # for each even/odd dimension pair of each position
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** (i / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))

        # add the batch dimension and register as a buffer so it follows the
        # module across devices without being a trainable parameter
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # return only as many positions as the input sequence has
        return self.pe[:, :x.size(1)]


class BERTEmbedding(torch.nn.Module):
    """
    BERT embedding, composed of the following features:
        1. TokenEmbedding : normal embedding matrix
        2. PositionalEmbedding : adds positional information using sin, cos
        3. SegmentEmbedding : adds sentence segment info, (sent_A:1, sent_B:2)
    The sum of these features is the output of BERTEmbedding.
    """

    def __init__(self, vocab_size, embed_size, seq_len=64, dropout=0.1):
        """
        :param vocab_size: total vocab size
        :param embed_size: embedding size of token embedding
        :param seq_len: maximum sequence length handled by the positional embedding
        :param dropout: dropout rate
        """
        super().__init__()
        self.embed_size = embed_size
        # (m, seq_len) --> (m, seq_len, embed_size)
        # padding_idx is not updated during training, remains as fixed pad (0)
        self.token = torch.nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.segment = torch.nn.Embedding(3, embed_size, padding_idx=0)
        self.position = PositionalEmbedding(d_model=embed_size, max_len=seq_len)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x)


### attention layers
class MultiHeadedAttention(torch.nn.Module):

    def __init__(self, heads, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()

        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = torch.nn.Dropout(dropout)

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.output_linear = torch.nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, d_model)
        mask of shape: (batch_size, 1, max_len, max_len), or any shape broadcastable to the scores
        """
        # (batch_size, max_len, d_model)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)

        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))

        # fill masked (pad) positions with a very large negative number so they
        # contribute ~0 weight after the softmax
        # (batch_size, h, max_len, max_len)
        scores = scores.masked_fill(mask == 0, -1e9)

        # (batch_size, h, max_len, max_len)
        # softmax over the last dim gives attention weights for all non-pad tokens
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)

        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

        # (batch_size, max_len, d_model)
        return self.output_linear(context)


class FeedForward(torch.nn.Module):
    "Implements the position-wise feed-forward network (FFN) equation."

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super(FeedForward, self).__init__()

        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out


class EncoderLayer(torch.nn.Module):

    def __init__(
        self,
        d_model=768,
        heads=12,
        feed_forward_hidden=768 * 4,
        dropout=0.1
    ):
        super(EncoderLayer, self).__init__()
        # note: a single LayerNorm instance (shared parameters) is reused for both sub-layers
        self.layernorm = torch.nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, max_len, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual connection + layer norm
        interacted = self.layernorm(interacted + embeddings)
        # position-wise feed-forward sub-layer
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded


class BERT(torch.nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, heads=12, dropout=0.1):
        """
        :param vocab_size: total number of words in the vocabulary
        :param d_model: BERT model hidden size
        :param n_layers: number of Transformer blocks (layers)
        :param heads: number of attention heads
        :param dropout: dropout rate
        """
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.heads = heads

        # the paper notes they used 4 * hidden_size for the feed-forward hidden size
        self.feed_forward_hidden = d_model * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model)

        # multi-layer transformer blocks, deep network
        self.encoder_blocks = torch.nn.ModuleList(
            [EncoderLayer(d_model, heads, d_model * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention masking for padded tokens
        # (batch_size, 1, seq_len, seq_len)
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence into a sequence of vectors
        x = self.embedding(x, segment_info)

        # running over multiple transformer blocks
        for encoder in self.encoder_blocks:
            x = encoder(x, mask)
        return x


class NextSentencePrediction(torch.nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, 2)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # use only the first token, which is the [CLS] token
        return self.softmax(self.linear(x[:, 0]))


class MaskedLanguageModel(torch.nn.Module):
    """
    predicting the original token from the masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, vocab_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))


class BERTLM(torch.nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """
        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.d_model)
        self.mask_lm = MaskedLanguageModel(self.bert.d_model, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)
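

# --- Usage sketch (not part of the original model code) ----------------------
# A minimal smoke test showing how the pieces above fit together. The
# hyperparameters, dummy batch, and loss setup are illustrative assumptions,
# not values prescribed by the original implementation; in particular, pairing
# torch.nn.NLLLoss with the LogSoftmax outputs of the two heads and using
# ignore_index=0 (0 assumed to be the pad id) is one reasonable choice.
if __name__ == "__main__":
    vocab_size, seq_len, batch_size = 30000, 64, 2  # assumed toy settings

    # small model so the example runs quickly on CPU
    bert = BERT(vocab_size=vocab_size, d_model=128, n_layers=2, heads=4, dropout=0.1)
    model = BERTLM(bert, vocab_size)

    # dummy batch: token id 0 is padding, segment labels are 1 (sent_A) / 2 (sent_B)
    tokens = torch.randint(1, vocab_size, (batch_size, seq_len))
    tokens[:, -5:] = 0  # pretend the last few positions are padding
    segments = torch.cat(
        [torch.ones(batch_size, seq_len // 2, dtype=torch.long),
         2 * torch.ones(batch_size, seq_len // 2, dtype=torch.long)], dim=1)
    segments[tokens == 0] = 0

    nsp_out, mlm_out = model(tokens, segments)
    print(nsp_out.shape)  # (batch_size, 2): log-probs for is_next / is_not_next
    print(mlm_out.shape)  # (batch_size, seq_len, vocab_size): log-probs per position

    # illustrative losses: NLLLoss, because the heads already apply LogSoftmax
    nsp_target = torch.randint(0, 2, (batch_size,))
    mlm_target = torch.randint(0, vocab_size, (batch_size, seq_len))
    nsp_loss = torch.nn.NLLLoss()(nsp_out, nsp_target)
    mlm_loss = torch.nn.NLLLoss(ignore_index=0)(mlm_out.transpose(1, 2), mlm_target)
    print((nsp_loss + mlm_loss).item())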