|
import torch |
|
import torch.nn as nn |
|
|
|
class LSTMPredictor(nn.Module):
    """Multi-step forecaster: a stacked LSTM followed by a linear head.

    The LSTM encodes the input sequence. Its output at the final time step
    is projected to ``forecast_horizon * output_dim`` values. These are then
    reshaped into one ``output_dim``-sized prediction per future step.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        forecast_horizon: int = 3,
        n_layers: int = 2,
        dropout: float = 0.2,
    ):
        """Build the model.

        Args:
            input_dim: Number of features per input time step.
            hidden_dim: Size of the LSTM hidden state.
            output_dim: Number of predicted features per forecast step.
            forecast_horizon: Number of future steps predicted at once.
            n_layers: Number of stacked LSTM layers.
            dropout: Dropout probability applied between stacked LSTM
                layers. It is ignored when ``n_layers == 1``, because
                ``nn.LSTM`` only applies dropout between layers.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.forecast_horizon = forecast_horizon
        # Kept so forward() can reshape explicitly instead of re-deriving it.
        self.output_dim = output_dim

        # nn.LSTM warns when dropout > 0 with a single layer (it would be a
        # no-op), so only pass it through when there are layers to separate.
        lstm_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_dim, hidden_dim, n_layers, batch_first=True, dropout=lstm_dropout
        )
        self.fc = nn.Linear(hidden_dim, output_dim * forecast_horizon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict the next ``forecast_horizon`` steps.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)``. The LSTM is
                built with ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, forecast_horizon, output_dim)``.
        """
        lstm_out, _ = self.lstm(x)
        # Only the representation after the last time step feeds the head.
        predictions = self.fc(lstm_out[:, -1, :])
        # Use the explicit batch size rather than -1. A -1 dimension is
        # ambiguous (and raises) when the batch is empty.
        return predictions.view(x.size(0), self.forecast_horizon, self.output_dim)
|
|