import math

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import (
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

# Seed the global RNG at import time so that parameter initialisation of the
# modules defined below is reproducible from run to run.
torch.manual_seed(0)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    A ``(maxlen, 1, emb_size)`` table is precomputed once and registered as the
    buffer ``pos_embedding``. Even feature columns hold sines and odd columns
    hold cosines of ``position * 10000 ** (-2i / emb_size)``.

    Args:
        emb_size: Size of the embedding (feature) dimension.
        dropout: Dropout probability applied after the encoding is added.
        maxlen: Number of positions to precompute; inputs longer than this
            along dimension 0 cannot be encoded.
    """

    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super().__init__()
        den = torch.exp(
            -torch.arange(0, emb_size, 2) * math.log(10000) / emb_size
        )
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den

        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # For an odd emb_size there is one fewer odd column than even columns,
        # so trim the angle table to fit; for an even emb_size this is a no-op.
        pos_embedding[:, 1::2] = torch.cos(angles[:, : emb_size // 2])
        # Insert a singleton axis so the table broadcasts over dimension 1 of
        # the input (the batch axis for sequence-first tensors).
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor) -> Tensor:
        """Return ``token_embedding`` plus positional encoding, after dropout.

        Dimension 0 of ``token_embedding`` is treated as the sequence position.
        """
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])


class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_size: Size of each embedding vector.
    """

    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor) -> Tensor:
        """Look up ``tokens`` (cast to int64) and scale by ``sqrt(emb_size)``."""
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class BanglaTransformer(nn.Module):
    """Encoder-decoder Transformer built from ``torch.nn`` Transformer layers.

    Source and target tokens are embedded by separate ``TokenEmbedding``
    modules, share one ``PositionalEncoding``, pass through a
    ``TransformerEncoder`` / ``TransformerDecoder`` stack, and are projected to
    target-vocabulary logits by a final linear layer (``generator``).

    The layers are created with PyTorch's default layout, so tensors are
    expected sequence-first (dimension 0 is the sequence position).

    Args:
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        emb_size: Model / embedding dimension (``d_model``).
        src_vocab_size: Size of the source vocabulary.
        tgt_vocab_size: Size of the target vocabulary.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        dropout: Dropout probability used by the positional encoding and by
            every encoder and decoder layer.
        nhead: Number of attention heads.
    """

    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward: int = 512, dropout: float = 0.1,
                 nhead: int = 8):
        super().__init__()
        # NOTE: submodules are created in a fixed order because each one draws
        # from the global RNG during parameter initialisation.
        encoder_layer = TransformerEncoderLayer(
            d_model=emb_size,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            # Forward the configured rate; otherwise the layer silently falls
            # back to its own default regardless of what the caller asked for.
            dropout=dropout,
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layer, num_layers=num_encoder_layers
        )

        decoder_layer = TransformerDecoderLayer(
            d_model=emb_size,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_decoder = TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers
        )

        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor) -> Tensor:
        """Run the full encoder-decoder pass and return target-vocab logits.

        Args:
            src: Source token indices.
            trg: Target token indices fed to the decoder.
            src_mask: Attention mask for the encoder self-attention.
            tgt_mask: Attention mask for the decoder self-attention.
            src_padding_mask: Key-padding mask for the source sequence.
            tgt_padding_mask: Key-padding mask for the target sequence.
            memory_key_padding_mask: Key-padding mask applied to the encoder
                output when the decoder attends to it.

        Returns:
            The decoder output projected by ``generator``; the last dimension
            has size ``tgt_vocab_size``.
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(
            src_emb,
            mask=src_mask,
            src_key_padding_mask=src_padding_mask,
        )
        outs = self.transformer_decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """Embed ``src`` and return the encoder output (no padding mask)."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer_encoder(src_emb, mask=src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor) -> Tensor:
        """Embed ``tgt`` and return the decoder output given ``memory``.

        The result is the raw decoder state; ``generator`` is not applied and
        no padding masks are used.
        """
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask)