import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """
        src: src_len x batch_size x img_channel
        outputs: src_len x batch_size x (enc_hid_dim * 2)
        hidden: batch_size x dec_hid_dim
        """
        embedded = self.dropout(src)
        outputs, hidden = self.rnn(embedded)
        # Concatenate the final forward and backward hidden states and project them
        # to the decoder hidden size; this becomes the decoder's initial hidden state.
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)))
        return outputs, hidden

class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        returns: batch_size x src_len (attention weights over source positions)
        """
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        # Repeat the decoder hidden state once per source timestep so it can be
        # scored against every encoder output.
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)
        return F.softmax(attention, dim=1)

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """
        input: batch_size
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        """
        # Decode a single timestep: embed the previous token, attend over the
        # encoder outputs, and feed the embedding plus weighted context to the GRU.
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        a = self.attention(hidden, encoder_outputs)
        a = a.unsqueeze(1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        weighted = torch.bmm(a, encoder_outputs)
        weighted = weighted.permute(1, 0, 2)
        rnn_input = torch.cat((embedded, weighted), dim=2)
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        # For a single-layer GRU run on one timestep, output and hidden are identical.
        assert (output == hidden).all()
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))
        return prediction, hidden.squeeze(0), a.squeeze(1)

class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
        super().__init__()
        attn = Attention(encoder_hidden, decoder_hidden)
        self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
        self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)

    def forward_encoder(self, src):
        """
        src: timestep x batch_size x channel
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        """
        encoder_outputs, hidden = self.encoder(src)
        return (hidden, encoder_outputs)

    def forward_decoder(self, tgt, memory):
        """
        tgt: timestep x batch_size
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (enc_hid_dim * 2)
        output: batch_size x 1 x vocab_size
        """
        # Decode one step from the most recent target token, threading the decoder
        # hidden state and encoder outputs through `memory`.
        tgt = tgt[-1]
        hidden, encoder_outputs = memory
        output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
        output = output.unsqueeze(1)
        return output, (hidden, encoder_outputs)

    def forward(self, src, trg):
        """
        src: time_step x batch_size x img_channel
        trg: time_step x batch_size
        outputs: batch_size x time_step x vocab_size
        """
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        device = src.device

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device)
        encoder_outputs, hidden = self.encoder(src)
        # Teacher forcing: feed the ground-truth token at every timestep.
        for t in range(trg_len):
            input = trg[t]
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
            outputs[t] = output
        outputs = outputs.transpose(0, 1).contiguous()
        return outputs

    def expand_memory(self, memory, beam_size):
        # Tile the decoder hidden state and encoder outputs along the batch axis
        # so that each beam hypothesis carries its own copy of the state.
        hidden, encoder_outputs = memory
        hidden = hidden.repeat(beam_size, 1)
        encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)
        return (hidden, encoder_outputs)

    def get_memory(self, memory, i):
        # Select the state belonging to a single beam hypothesis i.
        hidden, encoder_outputs = memory
        hidden = hidden[[i]]
        encoder_outputs = encoder_outputs[:, [i], :]
        return (hidden, encoder_outputs)
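

# A minimal usage sketch, assuming an OCR-style setup where a CNN backbone emits a
# feature sequence with 256 channels and the target vocabulary has 233 symbols;
# every size below is an illustrative assumption, not a value taken from this Space.
if __name__ == "__main__":
    vocab_size = 233        # assumed vocabulary size
    img_channel = 256       # assumed channel dim of the image feature sequence
    model = Seq2Seq(vocab_size=vocab_size, encoder_hidden=256, decoder_hidden=256,
                    img_channel=img_channel, decoder_embedded=256, dropout=0.1)

    src_len, trg_len, batch_size = 32, 10, 4
    src = torch.rand(src_len, batch_size, img_channel)           # feature sequence
    trg = torch.randint(0, vocab_size, (trg_len, batch_size))    # target token ids

    # Training-style forward pass with teacher forcing.
    out = model(src, trg)
    print(out.shape)   # batch_size x trg_len x vocab_size

    # Step-by-step decoding, as used at inference / beam-search time.
    memory = model.forward_encoder(src)
    step, memory = model.forward_decoder(trg[:1], memory)
    print(step.shape)  # batch_size x 1 x vocab_size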