import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """
        src: src_len x batch_size x img_channel
        outputs: src_len x batch_size x (enc_hid_dim * 2)
        hidden: batch_size x dec_hid_dim
        """
        embedded = self.dropout(src)

        outputs, hidden = self.rnn(embedded)

        # Combine the final forward and backward hidden states into the
        # initial decoder hidden state.
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)))

        return outputs, hidden


class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        returns attention weights: batch_size x src_len
        """
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]

        # Repeat the decoder hidden state across the source length so it can be
        # scored against every encoder output.
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)

        return F.softmax(attention, dim=1)


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """
        input: batch_size
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        """
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))

        # Attention weights over the encoder outputs for the current step.
        a = self.attention(hidden, encoder_outputs)
        a = a.unsqueeze(1)

        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        # Context vector: attention-weighted sum of the encoder outputs.
        weighted = torch.bmm(a, encoder_outputs)
        weighted = weighted.permute(1, 0, 2)

        rnn_input = torch.cat((embedded, weighted), dim=2)

        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))

        # With a single-layer, single-step GRU the output equals the new hidden state.
        assert (output == hidden).all()

        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))

        return prediction, hidden.squeeze(0), a.squeeze(1)


class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
        super().__init__()

        attn = Attention(encoder_hidden, decoder_hidden)

        self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
        self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)

    def forward_encoder(self, src):
        """
        src: timestep x batch_size x channel
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        """
        encoder_outputs, hidden = self.encoder(src)

        return (hidden, encoder_outputs)

    def forward_decoder(self, tgt, memory):
        """
        tgt: timestep x batch_size
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        output: batch_size x 1 x vocab_size
        """
        # Only the most recent target token is fed to the decoder.
        tgt = tgt[-1]
        hidden, encoder_outputs = memory

        output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
        output = output.unsqueeze(1)

        return output, (hidden, encoder_outputs)

    def forward(self, src, trg):
""" src: time_step x batch_size trg: time_step x batch_size outputs: batch_size x time_step x vocab_size """ batch_size = src.shape[1] trg_len = trg.shape[0] trg_vocab_size = self.decoder.output_dim device = src.device outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device) encoder_outputs, hidden = self.encoder(src) for t in range(trg_len): input = trg[t] output, hidden, _ = self.decoder(input, hidden, encoder_outputs) outputs[t] = output outputs = outputs.transpose(0, 1).contiguous() return outputs def expand_memory(self, memory, beam_size): hidden, encoder_outputs = memory hidden = hidden.repeat(beam_size, 1) encoder_outputs = encoder_outputs.repeat(1, beam_size, 1) return (hidden, encoder_outputs) def get_memory(self, memory, i): hidden, encoder_outputs = memory hidden = hidden[[i]] encoder_outputs = encoder_outputs[:, [i],:] return (hidden, encoder_outputs)