File size: 2,133 Bytes
9a41f63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch.nn as nn
import torch
from models.attention import Attention
from utils.config import config

class Decoder(nn.Module):
    """Single-step GRU decoder with attention over encoder outputs.

    Each call to ``forward`` consumes one target token per batch element,
    attends over ``encoder_outputs`` using the top layer of ``hidden``,
    feeds ``[embedding; context]`` through a stacked GRU, and projects
    ``[gru_output; context]`` to ``output_dim`` unnormalized scores.

    Args:
        output_dim: Target vocabulary size (rows of the embedding table and
            width of the output projection).
        embedding_dim: Size of each target-token embedding.
        hidden_dim: GRU hidden size. Encoder outputs must have this same
            feature size, because the attention context is concatenated with
            the embedding to form a GRU input of ``embedding_dim + hidden_dim``.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout probability applied to the embeddings, and between
            GRU layers when ``n_layers > 1``.
    """

    def __init__(self, output_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.attention = Attention(hidden_dim)
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.rnn = nn.GRU(
            embedding_dim + hidden_dim,
            hidden_dim,
            num_layers=n_layers,
            # nn.GRU only applies inter-layer dropout with >1 layer (and warns
            # otherwise), so pass 0 for a single-layer GRU.
            dropout=dropout if n_layers > 1 else 0,
        )
        # The projection sees only the GRU output and the attention context
        # (hence hidden_dim * 2 inputs); the embedding is not included.
        self.fc_out = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: Token indices, shape ``[batch_size]``.
            hidden: Previous GRU state, shape ``[n_layers, batch_size, hidden_dim]``.
            encoder_outputs: Shape ``[src_len, batch_size, hidden_dim]``.

        Returns:
            ``(prediction, hidden)``. ``prediction`` has shape
            ``[batch_size, output_dim]`` and holds raw scores (no softmax is
            applied here). ``hidden`` is the updated GRU state, shaped like
            the ``hidden`` argument.
        """
        # NOTE: ``input`` shadows the builtin; the name is kept so keyword
        # callers (``input=...``) continue to work.
        input = input.unsqueeze(0)
        # input: [1, batch_size] -- a length-1 sequence for the GRU

        embedded = self.dropout(self.embedding(input))
        # embedded: [1, batch_size, embedding_dim]

        # Attend using the top GRU layer's state. Attention is defined
        # elsewhere; it is assumed to return weights of shape
        # [src_len, batch_size] -- TODO confirm against models.attention.
        a = self.attention(hidden[-1], encoder_outputs)
        a = a.permute(1, 0).unsqueeze(1)
        # a: [batch_size, 1, src_len]

        # Batch-first layout so bmm can contract over src_len.
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # encoder_outputs: [batch_size, src_len, hidden_dim]

        # Attention-weighted sum of encoder outputs (the context vector),
        # moved back to sequence-first layout.
        weighted = torch.bmm(a, encoder_outputs).permute(1, 0, 2)
        # weighted: [1, batch_size, hidden_dim]

        rnn_input = torch.cat((embedded, weighted), dim=2)
        # rnn_input: [1, batch_size, embedding_dim + hidden_dim]

        output, hidden = self.rnn(rnn_input, hidden)
        # output: [1, batch_size, hidden_dim]
        # hidden: [n_layers, batch_size, hidden_dim]

        # Drop the length-1 sequence axis before the linear projection.
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted), dim=1))
        # prediction: [batch_size, output_dim]

        return prediction, hidden