# model.py — uploaded by "sparkles" (commit 23ea7d3, 1.07 kB).
import torch
import torch.nn as nn
class RNN_GEN(nn.Module):
    """Single-step, country-conditioned recurrent generator.

    Each call to :meth:`forward` consumes one step of input together with a
    country conditioning vector and the previous hidden state. It returns
    log-probabilities over ``output_dim`` classes and the next hidden state.

    Args:
        num_countries: Width of the country conditioning vector.
        input_dim: Width of the per-step input vector.
        output_dim: Number of output classes.
        hidden_dim: Width of the recurrent hidden state (default 128).
    """

    def __init__(self, num_countries, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        combined_dim = num_countries + input_dim + hidden_dim
        # Both first-stage layers read the same [country | input | hidden] concatenation.
        self.input_to_hidden = nn.Linear(combined_dim, hidden_dim)
        self.input_to_output = nn.Linear(combined_dim, output_dim)
        # Second output layer mixes the new hidden state with the first-stage output.
        self.output_to_output = nn.Linear(hidden_dim + output_dim, output_dim)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.hidden_dim = hidden_dim

    def forward(self, inp, country_tensor, hidden):
        """Run one recurrent step.

        All tensors are 2-D: they are concatenated along dim 1 and fed to
        ``nn.Linear``, which acts on the last dimension.

        Args:
            inp: ``(batch, input_dim)`` input for this step.
            country_tensor: ``(batch, num_countries)`` conditioning vector.
            hidden: ``(batch, hidden_dim)`` previous hidden state, e.g. from
                :meth:`init_hidden`.

        Returns:
            A tuple ``(log_probs, hidden)``:
            - ``log_probs`` is ``(batch, output_dim)``, log-softmax over dim 1.
            - ``hidden`` is the new ``(batch, hidden_dim)`` state.
        """
        # The concatenation order must match the in_features layout assumed
        # by input_to_hidden / input_to_output.
        combo = torch.cat([country_tensor, inp, hidden], dim=1)
        # NOTE(review): the hidden update is purely linear (no tanh/ReLU).
        # It is kept as-is, because adding a nonlinearity would change the
        # outputs of any weights already trained with this module.
        hidden = self.input_to_hidden(combo)
        out = self.input_to_output(combo)
        out_combo = torch.cat([hidden, out], dim=1)  # (batch, hidden_dim + output_dim)
        out = self.output_to_output(out_combo)  # (batch, output_dim)
        return self.log_softmax(out), hidden

    def init_hidden(self, batch_size=1):
        """Return an all-zero initial hidden state.

        The tensor is created on the same device, and with the same dtype, as
        the model's parameters. It can therefore be passed straight to
        :meth:`forward` after ``model.to(device)`` or ``model.double()``.

        Args:
            batch_size: Number of rows in the state; defaults to 1.

        Returns:
            A zero tensor of shape ``(batch_size, hidden_dim)``.
        """
        weight = self.input_to_hidden.weight
        return torch.zeros(
            (batch_size, self.hidden_dim), device=weight.device, dtype=weight.dtype
        )