import torch
import torch.nn as nn


class RNN_GEN(nn.Module):
    """Single-step conditional recurrent cell built from three linear layers.

    Each call to :meth:`forward` consumes one step of input together with a
    conditioning tensor and the previous hidden state, and produces
    log-probabilities over ``output_dim`` classes plus the next hidden state.

    Args:
        num_countries: Width (dim 1) of the conditioning tensor passed to
            ``forward`` as ``country_tensor``.
        input_dim: Width (dim 1) of the per-step input ``inp``.
        output_dim: Number of output classes.
        hidden_dim: Width of the hidden state.
    """

    def __init__(
        self,
        num_countries: int,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 128,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        # Both input-side layers read the same concatenated vector
        # [country, input, hidden], so they share one input width.
        combined_dim = num_countries + input_dim + hidden_dim
        self.input_to_hidden = nn.Linear(combined_dim, hidden_dim)
        self.input_to_output = nn.Linear(combined_dim, output_dim)
        # Refines the preliminary output using the freshly computed hidden state.
        self.output_to_output = nn.Linear(hidden_dim + output_dim, output_dim)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(
        self,
        inp: torch.Tensor,
        country_tensor: torch.Tensor,
        hidden: torch.Tensor,
    ):
        """Run one step of the cell.

        Args:
            inp: Step input of shape ``[batch, input_dim]``.
            country_tensor: Conditioning tensor of shape
                ``[batch, num_countries]`` (presumably a one-hot country
                encoding -- confirm against the caller).
            hidden: Previous hidden state of shape ``[batch, hidden_dim]``,
                e.g. the result of :meth:`init_hidden` for the first step.

        Returns:
            A tuple ``(out, hidden)`` where ``out`` holds log-probabilities of
            shape ``[batch, output_dim]`` and ``hidden`` is the next hidden
            state of shape ``[batch, hidden_dim]``.
        """
        # combo: [batch, num_countries + input_dim + hidden_dim]
        combo = torch.cat([country_tensor, inp, hidden], dim=1)

        # NOTE: the hidden update is a plain linear projection; no activation
        # is applied to it anywhere in this module.
        hidden = self.input_to_hidden(combo)
        out = self.input_to_output(combo)

        # out_combo: [batch, hidden_dim + output_dim]
        out_combo = torch.cat([hidden, out], dim=1)
        # out: [batch, output_dim]
        out = self.output_to_output(out_combo)
        out = self.log_softmax(out)
        return out, hidden

    def init_hidden(self, batch_size: int = 1) -> torch.Tensor:
        """Return an all-zero initial hidden state.

        The tensor is allocated with the same device and dtype as the module's
        parameters, so it stays compatible with ``forward`` after calls such
        as ``model.to(device)`` or ``model.double()``.

        Args:
            batch_size: Number of rows in the returned state. Defaults to 1,
                which matches the original single-sequence behaviour.

        Returns:
            A zero tensor of shape ``[batch_size, hidden_dim]``.
        """
        reference = self.input_to_hidden.weight
        return reference.new_zeros((batch_size, self.hidden_dim))