import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, emb_dim, hid_dim, n_layers, kernel_size, dropout, device, max_length=512):
        super().__init__()

        assert kernel_size % 2 == 1, "Kernel size must be odd!"

        self.device = device
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)

        # the encoder consumes pre-embedded features, so no token embedding is used here
        # self.tok_embedding = nn.Embedding(input_dim, emb_dim)
        self.pos_embedding = nn.Embedding(max_length, emb_dim)

        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.hid2emb = nn.Linear(hid_dim, emb_dim)

        self.convs = nn.ModuleList([nn.Conv1d(in_channels=hid_dim,
                                              out_channels=2 * hid_dim,
                                              kernel_size=kernel_size,
                                              padding=(kernel_size - 1) // 2)
                                    for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = [src len, batch size, emb dim]
        src = src.transpose(0, 1)
        # src = [batch size, src len, emb dim]

        batch_size = src.shape[0]
        src_len = src.shape[1]
        device = src.device

        # create position tensor
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(device)
        # pos = [0, 1, 2, 3, ..., src len - 1]
        # pos = [batch size, src len]

        # src is already embedded, so only the positions are embedded here
        # tok_embedded = self.tok_embedding(src)
        tok_embedded = src
        pos_embedded = self.pos_embedding(pos)
        # tok_embedded = pos_embedded = [batch size, src len, emb dim]

        # combine embeddings by elementwise summing
        embedded = self.dropout(tok_embedded + pos_embedded)
        # embedded = [batch size, src len, emb dim]

        # pass embedded through linear layer to convert from emb dim to hid dim
        conv_input = self.emb2hid(embedded)
        # conv_input = [batch size, src len, hid dim]

        # permute for convolutional layer
        conv_input = conv_input.permute(0, 2, 1)
        # conv_input = [batch size, hid dim, src len]

        # begin convolutional blocks...
        for i, conv in enumerate(self.convs):
            # pass through convolutional layer
            conved = conv(self.dropout(conv_input))
            # conved = [batch size, 2 * hid dim, src len]

            # pass through GLU activation function
            conved = F.glu(conved, dim=1)
            # conved = [batch size, hid dim, src len]

            # apply residual connection
            conved = (conved + conv_input) * self.scale
            # conved = [batch size, hid dim, src len]

            # set conv_input to conved for next loop iteration
            conv_input = conved
        # ...end convolutional blocks

        # permute and convert back to emb dim
        conved = self.hid2emb(conved.permute(0, 2, 1))
        # conved = [batch size, src len, emb dim]

        # elementwise sum output (conved) and input (embedded) to be used for attention
        combined = (conved + embedded) * self.scale
        # combined = [batch size, src len, emb dim]

        return conved, combined
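# Usage sketch (explanatory note, not part of the original module; the shapes below
# are illustrative assumptions). The encoder expects *pre-embedded* features rather
# than token indices, since its token embedding is commented out and
# `tok_embedded = src`:
#
#   enc = Encoder(emb_dim=256, hid_dim=512, n_layers=3, kernel_size=3,
#                 dropout=0.1, device=torch.device("cpu"))
#   src = torch.randn(26, 8, 256)        # [src len, batch size, emb dim]
#   conved, combined = enc(src)          # each [batch size, src len, emb dim]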
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, kernel_size, dropout,
                 trg_pad_idx, device, max_length=512):
        super().__init__()

        self.kernel_size = kernel_size
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)

        self.tok_embedding = nn.Embedding(output_dim, emb_dim)
        self.pos_embedding = nn.Embedding(max_length, emb_dim)

        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.hid2emb = nn.Linear(hid_dim, emb_dim)

        self.attn_hid2emb = nn.Linear(hid_dim, emb_dim)
        self.attn_emb2hid = nn.Linear(emb_dim, hid_dim)

        self.fc_out = nn.Linear(emb_dim, output_dim)

        self.convs = nn.ModuleList([nn.Conv1d(in_channels=hid_dim,
                                              out_channels=2 * hid_dim,
                                              kernel_size=kernel_size)
                                    for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

    def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined):
        # embedded = [batch size, trg len, emb dim]
        # conved = [batch size, hid dim, trg len]
        # encoder_conved = encoder_combined = [batch size, src len, emb dim]

        # permute and convert back to emb dim
        conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1))
        # conved_emb = [batch size, trg len, emb dim]

        combined = (conved_emb + embedded) * self.scale
        # combined = [batch size, trg len, emb dim]

        energy = torch.matmul(combined, encoder_conved.permute(0, 2, 1))
        # energy = [batch size, trg len, src len]

        attention = F.softmax(energy, dim=2)
        # attention = [batch size, trg len, src len]

        attended_encoding = torch.matmul(attention, encoder_combined)
        # attended_encoding = [batch size, trg len, emb dim]

        # convert from emb dim -> hid dim
        attended_encoding = self.attn_emb2hid(attended_encoding)
        # attended_encoding = [batch size, trg len, hid dim]

        # apply residual connection
        attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scale
        # attended_combined = [batch size, hid dim, trg len]

        return attention, attended_combined

    def forward(self, trg, encoder_conved, encoder_combined):
        # trg = [trg len, batch size]
        # encoder_conved = encoder_combined = [batch size, src len, emb dim]
        trg = trg.transpose(0, 1)
        # trg = [batch size, trg len]

        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        device = trg.device

        # create position tensor
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(device)
        # pos = [batch size, trg len]

        # embed tokens and positions
        tok_embedded = self.tok_embedding(trg)
        pos_embedded = self.pos_embedding(pos)
        # tok_embedded = pos_embedded = [batch size, trg len, emb dim]

        # combine embeddings by elementwise summing
        embedded = self.dropout(tok_embedded + pos_embedded)
        # embedded = [batch size, trg len, emb dim]

        # pass embedded through linear layer to go from emb dim -> hid dim
        conv_input = self.emb2hid(embedded)
        # conv_input = [batch size, trg len, hid dim]

        # permute for convolutional layer
        conv_input = conv_input.permute(0, 2, 1)
        # conv_input = [batch size, hid dim, trg len]

        batch_size = conv_input.shape[0]
        hid_dim = conv_input.shape[1]

        for i, conv in enumerate(self.convs):
            # apply dropout
            conv_input = self.dropout(conv_input)

            # pad on the left so the decoder can't "cheat" by seeing future tokens
            padding = torch.zeros(batch_size, hid_dim, self.kernel_size - 1).fill_(self.trg_pad_idx).to(device)
            padded_conv_input = torch.cat((padding, conv_input), dim=2)
            # padded_conv_input = [batch size, hid dim, trg len + kernel size - 1]

            # pass through convolutional layer
            conved = conv(padded_conv_input)
            # conved = [batch size, 2 * hid dim, trg len]

            # pass through GLU activation function
            conved = F.glu(conved, dim=1)
            # conved = [batch size, hid dim, trg len]

            # calculate attention
            attention, conved = self.calculate_attention(embedded, conved, encoder_conved, encoder_combined)
            # attention = [batch size, trg len, src len]

            # apply residual connection
            conved = (conved + conv_input) * self.scale
            # conved = [batch size, hid dim, trg len]

            # set conv_input to conved for next loop iteration
            conv_input = conved

        conved = self.hid2emb(conved.permute(0, 2, 1))
        # conved = [batch size, trg len, emb dim]

        output = self.fc_out(self.dropout(conved))
        # output = [batch size, trg len, output dim]

        return output, attention
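# Causality note (explanatory comment, not original code): inside Decoder.forward the
# input to each Conv1d is left-padded with kernel_size - 1 steps and the convolution
# itself uses no padding, so output position i is computed from padded positions
# i .. i + kernel_size - 1, which are original positions <= i. For example, with
# kernel_size = 3 and trg len = 5, the padded length is 7 and the conv output at
# step 2 only sees target steps 0, 1 and 2; the decoder never sees future tokens.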
class ConvSeq2Seq(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, enc_layers, dec_layers,
                 enc_kernel_size, dec_kernel_size, enc_max_length, dec_max_length,
                 dropout, pad_idx, device):
        super().__init__()

        enc = Encoder(emb_dim, hid_dim, enc_layers, enc_kernel_size, dropout, device, enc_max_length)
        dec = Decoder(vocab_size, emb_dim, hid_dim, dec_layers, dec_kernel_size, dropout,
                      pad_idx, device, dec_max_length)

        self.encoder = enc
        self.decoder = dec

    def forward_encoder(self, src):
        encoder_conved, encoder_combined = self.encoder(src)
        return encoder_conved, encoder_combined

    def forward_decoder(self, trg, memory):
        encoder_conved, encoder_combined = memory
        output, attention = self.decoder(trg, encoder_conved, encoder_combined)
        return output, (encoder_conved, encoder_combined)

    def forward(self, src, trg):
        # src = [src len, batch size, emb dim]
        # trg = [trg len, batch size] (<eos> token sliced off the end)

        # calculate z^u (encoder_conved) and (z^u + e) (encoder_combined)
        # encoder_conved is the output from the final encoder conv. block
        # encoder_combined is encoder_conved plus (elementwise) the src embedding plus
        #  the positional embeddings
        encoder_conved, encoder_combined = self.encoder(src)
        # encoder_conved = [batch size, src len, emb dim]
        # encoder_combined = [batch size, src len, emb dim]

        # calculate predictions of next words
        # output is a batch of predictions for each word in the trg sentence
        # attention is a batch of attention scores across the src sentence for
        #  each word in the trg sentence
        output, attention = self.decoder(trg, encoder_conved, encoder_combined)
        # output = [batch size, trg len - 1, output dim]
        # attention = [batch size, trg len - 1, src len]

        return output  # , attention
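# Minimal smoke test (sketch only; every size below is an assumption chosen for
# illustration, and in a real pipeline `src` would be already-embedded features).
if __name__ == "__main__":
    device = torch.device("cpu")
    model = ConvSeq2Seq(vocab_size=100, emb_dim=256, hid_dim=512,
                        enc_layers=3, dec_layers=3,
                        enc_kernel_size=3, dec_kernel_size=3,
                        enc_max_length=512, dec_max_length=512,
                        dropout=0.1, pad_idx=0, device=device)
    model.eval()

    src = torch.randn(26, 4, 256)                            # [src len, batch size, emb dim]
    trg = torch.randint(0, 100, (15, 4), dtype=torch.long)   # [trg len, batch size]

    with torch.no_grad():
        output = model(src, trg)

    print(output.shape)   # expected: torch.Size([4, 15, 100]) = [batch size, trg len, vocab size]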