import math

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly
# insert them in vocab
UNK, PAD, BOS, EOS = "<unk>", "<pad>", "<bos>", "<eos>"
SPECIAL_SYMBOLS = [UNK, PAD, BOS, EOS]


# helper Module that adds positional encoding to the
# token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(
            -torch.arange(0, emb_size, 2) * math.log(10000) / emb_size
        )
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # Shape (maxlen, 1, emb_size) so it broadcasts over the batch
        # dimension of a (seq_len, batch_size, emb_size) input.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(
            token_embedding + self.pos_embedding[:token_embedding.size(0), :]
        )


# helper Module to convert tensor of input indices into
# corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        emb_size: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 512,
        dropout: float = 0.1
    ):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout
        )
        self._init()

    def _init(self):
        # Xavier-initialize all weight matrices (biases are left as-is).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src: Tensor,
        trg: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
        src_padding_mask: Tensor,
        tgt_padding_mask: Tensor,
        memory_key_padding_mask: Tensor
    ):
        # shape: [seq_len, batch_size, emb_size]
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(
            src_emb, tgt_emb, src_mask, tgt_mask, None,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask
        )
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(
            self.positional_encoding(self.src_tok_emb(src)), src_mask
        )

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask
        )


# Lower-triangular ("causal") attention mask: position i may only attend to
# positions <= i. Allowed positions hold 0.0, blocked positions hold -inf.
def generate_square_subsequent_mask(sz, device):
    mask = (
        torch.triu(torch.ones((sz, sz), device=device)) == 1
    ).transpose(0, 1)
    mask = mask.float().masked_fill(
        mask == 0, float('-inf')
    ).masked_fill(mask == 1, float(0.0))
    return mask


def create_tgt_mask(tgt, device):
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
    # Padding mask in the layout nn.Transformer expects: shape
    # (batch_size, seq_len), True marks positions to ignore.
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return tgt_mask, tgt_padding_mask


def create_mask(src, tgt, device):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
    # The encoder has no causal constraint, so its attention mask is all False.
    src_mask = torch.zeros(
        (src_seq_len, src_seq_len), device=device
    ).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
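

# ---------------------------------------------------------------------------
# Illustrative additions, not part of the original definitions above.
# greedy_decode is a minimal inference sketch: encode the source once, then
# repeatedly run the decoder on the tokens generated so far and append the
# arg-max token until EOS_IDX is produced or max_len is reached. Tensors use
# the (seq_len, batch_size) layout assumed throughout this file, batch size 1.
# ---------------------------------------------------------------------------
def greedy_decode(model, src, src_mask, max_len, start_symbol, device):
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encode(src, src_mask)
    # Start the target sequence with the BOS token.
    ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)
    for _ in range(max_len - 1):
        # Convert the float mask to bool: -inf -> True (blocked), 0.0 -> False.
        tgt_mask = generate_square_subsequent_mask(
            ys.size(0), device
        ).type(torch.bool)
        out = model.decode(ys, memory, tgt_mask)
        # Project only the last position to vocabulary logits.
        prob = model.generator(out[-1])
        next_word = int(prob.argmax(dim=1).item())

        ys = torch.cat(
            [ys, torch.full((1, 1), next_word, dtype=torch.long, device=device)],
            dim=0
        )
        if next_word == EOS_IDX:
            break
    return ys


# ---------------------------------------------------------------------------
# Minimal usage sketch. The vocabulary sizes, sequence lengths and batch size
# below are made-up placeholders, and the model is untrained, so the decoded
# tokens are meaningless; the point is only to show the expected tensor
# layouts and how the masks plug into the forward pass.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    SRC_VOCAB_SIZE, TGT_VOCAB_SIZE = 1000, 1200  # placeholder vocab sizes
    SRC_SEQ_LEN, TGT_SEQ_LEN, BATCH_SIZE = 12, 10, 4

    model = Seq2SeqTransformer(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE).to(device)

    # Random token indices standing in for a batched (src, tgt) pair,
    # laid out as (seq_len, batch_size).
    src = torch.randint(4, SRC_VOCAB_SIZE, (SRC_SEQ_LEN, BATCH_SIZE), device=device)
    tgt = torch.randint(4, TGT_VOCAB_SIZE, (TGT_SEQ_LEN, BATCH_SIZE), device=device)

    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(
        src, tgt, device
    )
    logits = model(
        src, tgt, src_mask, tgt_mask,
        src_padding_mask, tgt_padding_mask, src_padding_mask
    )
    print(logits.shape)  # torch.Size([TGT_SEQ_LEN, BATCH_SIZE, TGT_VOCAB_SIZE])

    # Greedy decoding of the first source sequence in the batch.
    single_src = src[:, :1]
    single_src_mask = torch.zeros(
        (SRC_SEQ_LEN, SRC_SEQ_LEN), device=device
    ).type(torch.bool)
    generated = greedy_decode(
        model, single_src, single_src_mask,
        max_len=20, start_symbol=BOS_IDX, device=device
    )
    print(generated.squeeze(1).tolist())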