File size: 1,069 Bytes
23ea7d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch 
import torch.nn as nn 

class RNN_GEN(nn.Module):
    """Single-step recurrent cell conditioned on a country vector.

    Each step concatenates the country tensor, the current input and the
    previous hidden state, and feeds that combined vector through two
    linear layers (one producing the next hidden state, one producing a
    preliminary output). A third linear layer mixes the new hidden state
    with the preliminary output, and the result is normalised with a
    log-softmax over dimension 1.

    Note that no non-linearity is applied to the hidden state.

    Args:
        num_countries: Width of the country tensor passed to ``forward``.
        input_dim: Width of the per-step input tensor.
        output_dim: Width of the per-step output (log-probabilities).
        hidden_dim: Width of the hidden state.
    """

    def __init__(self, num_countries: int, input_dim: int, output_dim: int,
                 hidden_dim: int = 128):
        super().__init__()
        # Both first-stage layers consume [country | input | hidden].
        self.input_to_hidden = nn.Linear(num_countries + input_dim + hidden_dim, hidden_dim)
        self.input_to_output = nn.Linear(num_countries + input_dim + hidden_dim, output_dim)
        # Second stage consumes [new hidden | preliminary output].
        self.output_to_output = nn.Linear(hidden_dim + output_dim, output_dim)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.hidden_dim = hidden_dim

    def forward(self, inp: torch.Tensor, country_tensor: torch.Tensor,
                hidden: torch.Tensor):
        """Run one recurrent step.

        All three tensors are concatenated along dimension 1, so they are
        expected to be 2-D and to share the same leading (batch) size:
        ``country_tensor`` is ``(batch, num_countries)``, ``inp`` is
        ``(batch, input_dim)`` and ``hidden`` is ``(batch, hidden_dim)``.

        Returns:
            A pair ``(out, hidden)`` where ``out`` holds log-probabilities of
            shape ``(batch, output_dim)`` and ``hidden`` is the next hidden
            state of shape ``(batch, hidden_dim)``.
        """
        combo = torch.cat([country_tensor, inp, hidden], dim=1)
        hidden = self.input_to_hidden(combo)
        out = self.input_to_output(combo)
        # out_combo has shape (batch, hidden_dim + output_dim)
        out_combo = torch.cat([hidden, out], dim=1)
        # out has shape (batch, output_dim)
        out = self.output_to_output(out_combo)
        out = self.log_softmax(out)
        return out, hidden

    def init_hidden(self, batch_size: int = 1) -> torch.Tensor:
        """Return an all-zero initial hidden state.

        The tensor is allocated from the layer weights so that it lives on
        the same device and has the same dtype as the model; a plain
        ``torch.zeros`` would always be CPU/float32 and mismatch a model
        that has been moved or cast.

        Args:
            batch_size: Number of rows in the returned state. Defaults to 1,
                matching the previous fixed behaviour.

        Returns:
            A zero tensor of shape ``(batch_size, hidden_dim)``.
        """
        return self.input_to_hidden.weight.new_zeros((batch_size, self.hidden_dim))