import torch
import torch.nn as nn


class LSTMPredictor(nn.Module):
    """Multi-step forecaster: a stacked LSTM followed by a linear head.

    The hidden state of the last time step is projected to
    ``output_dim * forecast_horizon`` values and reshaped so that each
    forecast step gets its own ``output_dim``-sized vector.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        forecast_horizon: int = 3,
        n_layers: int = 2,
        dropout: float = 0.2,
    ) -> None:
        """Build the LSTM encoder and the linear projection head.

        Args:
            input_dim: Number of features per time step of the input.
            hidden_dim: Size of the LSTM hidden state.
            output_dim: Number of values predicted per forecast step.
            forecast_horizon: Number of future steps predicted at once.
            n_layers: Number of stacked LSTM layers.
            dropout: Dropout probability applied between LSTM layers.
                Ignored when ``n_layers == 1`` (there is no "between").

        Raises:
            ValueError: If ``forecast_horizon`` is not a positive integer.
        """
        super().__init__()
        if forecast_horizon < 1:
            raise ValueError(
                f"forecast_horizon must be >= 1, got {forecast_horizon}"
            )

        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.forecast_horizon = forecast_horizon
        self.output_dim = output_dim

        # nn.LSTM only applies dropout between stacked layers; passing a
        # non-zero value with a single layer has no effect and triggers a
        # UserWarning, so it is disabled explicitly in that case.
        inter_layer_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            n_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.fc = nn.Linear(hidden_dim, output_dim * forecast_horizon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict ``forecast_horizon`` future steps from an input sequence.

        Args:
            x: Batched input of shape ``(batch, seq_len, input_dim)``
                (the LSTM is constructed with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, forecast_horizon, output_dim)``.
        """
        lstm_out, _ = self.lstm(x)
        # Only the output at the final time step feeds the projection head.
        last_step = lstm_out[:, -1, :]
        predictions = self.fc(last_step)
        return predictions.view(-1, self.forecast_horizon, self.output_dim)