# Provenance: duplicated from the hantech/VietOCR repository on Hugging Face
# (author: hantech; commit 33c0fae; 5.91 kB; Hugging Face scan: "No virus").
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class Encoder(nn.Module):
    """Bidirectional GRU encoder that also derives the decoder's initial state.

    Args:
        emb_dim: feature size of each input timestep (GRU input size).
        enc_hid_dim: hidden size of each GRU direction.
        dec_hid_dim: size of the projected initial decoder hidden state.
        dropout: dropout probability applied to the input features.
    """

    def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        # Attribute names and creation order are deliberately unchanged: they
        # define the state_dict keys and the order random init consumes RNG.
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(2 * enc_hid_dim, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode ``src`` and compute the decoder's starting hidden state.

        Args:
            src: src_len x batch_size x emb_dim (the GRU is built without
                ``batch_first``, so time is the leading axis).

        Returns:
            ``(outputs, hidden)``: outputs is
            src_len x batch_size x (2 * enc_hid_dim); hidden is
            batch_size x dec_hid_dim, squashed with tanh.
        """
        features = self.dropout(src)
        outputs, final_states = self.rnn(features)
        # final_states stacks one entry per (layer, direction); the last two
        # are the forward and backward final states of the top layer.
        forward_last = final_states[-2]
        backward_last = final_states[-1]
        merged = torch.cat([forward_last, backward_last], dim=1)
        hidden = torch.tanh(self.fc(merged))
        return outputs, hidden
class Attention(nn.Module):
    """Additive attention scoring each encoder timestep against the decoder state.

    score_j = v^T tanh(W [hidden; encoder_output_j]), softmax-normalised over
    the source axis.

    Args:
        enc_hid_dim: per-direction encoder hidden size (encoder outputs carry
            ``2 * enc_hid_dim`` features).
        dec_hid_dim: decoder hidden size.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights over the source sequence.

        Args:
            hidden: batch_size x dec_hid_dim decoder state.
            encoder_outputs: src_len x batch_size x (2 * enc_hid_dim).

        Returns:
            batch_size x src_len tensor; each row sums to 1.
        """
        src_len = encoder_outputs.shape[0]
        # Broadcast the decoder state along the source axis. expand() returns
        # a view, so no batch x src_len x dec_hid_dim copy is materialised
        # before the concatenation below.
        hidden = hidden.unsqueeze(1).expand(-1, src_len, -1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # batch x src_len x 2*enc
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)  # batch x src_len
        return F.softmax(attention, dim=1)
class Decoder(nn.Module):
    """Single-step GRU decoder conditioned on an attention context vector.

    Args:
        output_dim: vocabulary size (embedding table rows and output logits).
        emb_dim: size of the token embedding.
        enc_hid_dim: per-direction encoder hidden size (encoder outputs carry
            ``2 * enc_hid_dim`` features).
        dec_hid_dim: hidden size of the decoder GRU.
        dropout: dropout probability applied to the token embedding.
        attention: callable ``(hidden, encoder_outputs)`` returning
            batch_size x src_len attention weights.
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: batch_size tensor of token indices for this step.
            hidden: batch_size x dec_hid_dim previous decoder state.
            encoder_outputs: src_len x batch_size x (2 * enc_hid_dim).

        Returns:
            ``(prediction, hidden, attention)`` with shapes
            batch_size x output_dim, batch_size x dec_hid_dim and
            batch_size x src_len.
        """
        tokens = input.unsqueeze(0)                        # 1 x batch
        embedded = self.dropout(self.embedding(tokens))    # 1 x batch x emb_dim
        a = self.attention(hidden, encoder_outputs).unsqueeze(1)   # batch x 1 x src_len
        # Attention-weighted sum of the encoder outputs (context vector).
        weighted = torch.bmm(a, encoder_outputs.permute(1, 0, 2))  # batch x 1 x 2*enc
        weighted = weighted.permute(1, 0, 2)               # 1 x batch x 2*enc
        rnn_input = torch.cat((embedded, weighted), dim=2)
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        # One step through this single-layer, unidirectional GRU makes output
        # and hidden identical by construction. It is not re-checked here:
        # evaluating a tensor comparison as a bool would force a device->host
        # sync on every decode step.
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))
        return prediction, hidden.squeeze(0), a.squeeze(1)
class Seq2Seq(nn.Module):
    """GRU encoder plus attention GRU decoder sequence-to-sequence model.

    The "memory" passed between the incremental-decoding helpers is the tuple
    ``(hidden, encoder_outputs)``.

    Args:
        vocab_size: number of output token classes.
        encoder_hidden: hidden size of each encoder GRU direction.
        decoder_hidden: hidden size of the decoder GRU.
        img_channel: feature size of each input timestep fed to the encoder.
        decoder_embedded: size of the decoder token embedding.
        dropout: dropout probability passed to both encoder and decoder.
    """

    def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
        super().__init__()
        # Construction order is kept: it fixes the order random init runs.
        attn = Attention(encoder_hidden, decoder_hidden)
        self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
        self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)

    def forward_encoder(self, src):
        """Encode ``src`` and return the initial decoding memory.

        Args:
            src: timestep x batch_size x channel.

        Returns:
            ``(hidden, encoder_outputs)``: hidden is batch_size x hid_dim,
            encoder_outputs is src_len x batch_size x enc_features.
        """
        encoder_outputs, hidden = self.encoder(src)
        return (hidden, encoder_outputs)

    def forward_decoder(self, tgt, memory):
        """Decode one step from the last token of ``tgt``.

        Args:
            tgt: timestep x batch_size token indices; only ``tgt[-1]`` is used.
            memory: ``(hidden, encoder_outputs)`` as produced by
                ``forward_encoder`` or a previous call.

        Returns:
            ``(output, memory)``: output is batch_size x 1 x vocab_size
            logits; memory carries the updated hidden state and the unchanged
            encoder outputs.
        """
        last_token = tgt[-1]
        hidden, encoder_outputs = memory
        output, hidden, _ = self.decoder(last_token, hidden, encoder_outputs)
        return output.unsqueeze(1), (hidden, encoder_outputs)

    def forward(self, src, trg):
        """Teacher-forced decoding of the whole target sequence.

        Args:
            src: time_step x batch_size x channel encoder input.
            trg: time_step x batch_size token indices; ``trg[t]`` is fed to
                the decoder at step ``t``.

        Returns:
            batch_size x time_step x vocab_size logits (contiguous).
        """
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        # Allocate on the input's device directly rather than on the CPU
        # followed by a .to(device) copy.
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=src.device)
        encoder_outputs, hidden = self.encoder(src)
        for t in range(trg_len):
            output, hidden, _ = self.decoder(trg[t], hidden, encoder_outputs)
            outputs[t] = output
        return outputs.transpose(0, 1).contiguous()

    def expand_memory(self, memory, beam_size):
        """Tile the whole batch ``beam_size`` times along the batch axis.

        The copies are laid out batch-after-batch, so sample ``b`` of copy
        ``k`` sits at index ``k * batch_size + b``.
        """
        hidden, encoder_outputs = memory
        hidden = hidden.repeat(beam_size, 1)
        encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)
        return (hidden, encoder_outputs)

    def get_memory(self, memory, i):
        """Select sample ``i`` from the memory, keeping a batch axis of size 1."""
        hidden, encoder_outputs = memory
        hidden = hidden[[i]]
        encoder_outputs = encoder_outputs[:, [i], :]
        return (hidden, encoder_outputs)