import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    ## constructor
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(LSTMModel, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        # batch_first=True means the input is shaped (batch, seq_len, input_size)
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers,
                            batch_first=True)
        # linear layer projects each hidden state to the output dimension
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x, h0=None, c0=None):
        # initialise the hidden and cell state vectors h0 and c0 if none are given,
        # on the same device and dtype as the input
        if h0 is None or c0 is None:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                             device=x.device, dtype=x.dtype)
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                             device=x.device, dtype=x.dtype)
        out, (hn, cn) = self.lstm(x, (h0, c0))  # out: (batch, seq_len, hidden_size)
        out = self.fc(out)                      # apply the linear head at every time step
        return out, (hn, cn)
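As a quick sanity check, here is a minimal usage sketch; the batch size, sequence length, and feature/output dimensions are arbitrary placeholder values chosen for illustration, not taken from the text above.

# hypothetical dimensions, chosen only to demonstrate the expected shapes
model = LSTMModel(input_size=10, hidden_size=32, output_size=1, num_layers=2)
x = torch.randn(4, 15, 10)        # (batch=4, seq_len=15, input_size=10)
out, (hn, cn) = model(x)          # hidden and cell states default to zeros
print(out.shape)                  # torch.Size([4, 15, 1])
print(hn.shape, cn.shape)         # torch.Size([2, 4, 32]) torch.Size([2, 4, 32])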