Crisis_Severity_Predictor_LSTM / lstm_predictor.py
sayyedAhmed's picture
adding files
7d950d3
raw
history blame
705 Bytes
import torch
import torch.nn as nn
class LSTMPredictor(nn.Module):
    """Multi-step forecaster: stacked LSTM encoder followed by a linear head.

    The LSTM reads the whole input sequence (``batch_first=True``). The hidden
    output at the final time step is projected to
    ``forecast_horizon * output_dim`` values. Those values are reshaped into
    one ``output_dim``-sized vector per forecast step.

    Args:
        input_dim: Number of features per time step in the input.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of values predicted per forecast step.
        forecast_horizon: Number of future steps predicted at once.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers.
            It is forced to 0 when ``n_layers == 1``, because ``nn.LSTM``
            only applies dropout between layers and warns otherwise.

    Raises:
        ValueError: If ``forecast_horizon`` is less than 1.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, forecast_horizon=3, n_layers=2, dropout=0.2):
        super().__init__()
        if forecast_horizon < 1:
            raise ValueError(f"forecast_horizon must be >= 1, got {forecast_horizon}")
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.forecast_horizon = forecast_horizon
        # nn.LSTM applies dropout only between stacked layers. With a single
        # layer, a non-zero value does nothing except emit a UserWarning.
        lstm_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True, dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_dim, output_dim * forecast_horizon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict ``forecast_horizon`` future steps from an input sequence.

        Args:
            x: Input tensor of shape ``(batch, seq_len, input_dim)``.
                ``batch_first=True`` fixes this layout.

        Returns:
            Tensor of shape ``(batch, forecast_horizon, output_dim)``.
        """
        lstm_out, _ = self.lstm(x)
        # Only the representation at the last time step feeds the head.
        last_step = lstm_out[:, -1, :]
        predictions = self.fc(last_step)
        return predictions.view(last_step.size(0), self.forecast_horizon, self.output_dim)