# Source: Hugging Face Hub, uploaded by sayyedAhmed in commit 7d950d3
# ("adding files"); original file size 2.29 kB.
import json
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from lstm_predictor import LSTMPredictor
from huggingface_hub import hf_hub_download
class TimeSeriesDataset(Dataset):
    """Sliding-window view over a time series, used as model input.

    Item ``i`` is ``data[i:i + seq_length]`` converted to a ``float32`` tensor.
    Every full-length window is produced, including the last one ending at the
    final sample, so a series of ``N`` samples yields ``N - seq_length + 1``
    windows (or none when ``N < seq_length``).

    Args:
        data: Sliceable sequence indexed along its first axis (``main`` passes
            the NumPy array returned by ``np.load``); any trailing axes are
            kept as-is in each window.
        seq_length: Number of consecutive samples per window; must be >= 1.

    Raises:
        ValueError: If ``seq_length`` is smaller than 1.
    """

    def __init__(self, data, seq_length):
        if seq_length < 1:
            raise ValueError(f"seq_length must be >= 1, got {seq_length}")
        self.data = data
        self.seq_length = seq_length

    def __len__(self):
        # +1: the window starting at len(data) - seq_length is still complete.
        # Clamp at 0 so a series shorter than one window yields an empty
        # dataset instead of a negative length, which DataLoader rejects.
        return max(0, len(self.data) - self.seq_length + 1)

    def __getitem__(self, idx):
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        # Without this check an out-of-range index silently slices a short or
        # empty window; raising IndexError also ends plain sequence iteration.
        if not 0 <= idx < n_windows:
            raise IndexError(f"window index out of range for {n_windows} windows")
        return torch.tensor(self.data[idx:idx + self.seq_length], dtype=torch.float32)
def load_config(config_path):
    """Read a JSON configuration file and return the parsed value.

    Args:
        config_path: Path (``str`` or path-like) of the JSON file.

    Returns:
        The decoded JSON value; the rest of this script indexes it like a dict.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 (the encoding RFC 8259 requires for exchanged
    # JSON) instead of the locale-dependent default, e.g. cp1252 on Windows.
    with open(config_path, "r", encoding="utf-8") as file:
        return json.load(file)
def load_model(config):
    """Download the trained weights and return an ``LSTMPredictor`` ready for inference.

    Fetches ``config["model_path"]`` from the Hugging Face Hub repository
    ``config["repo_id"]``, builds an ``LSTMPredictor`` from the architecture
    entries of ``config``, loads the downloaded state dict into it, moves it to
    ``config["device"]`` and switches it to evaluation mode.

    Args:
        config: Mapping with ``repo_id``, ``model_path``, ``device`` and the
            architecture keys ``input_dim``, ``hidden_dim``, ``output_dim``,
            ``forecast_horizon``, ``n_layers`` and ``dropout``.

    Returns:
        The model, on the configured device, in eval mode.
    """
    weights_path = hf_hub_download(
        repo_id=config["repo_id"], filename=config["model_path"]
    )

    # Constructor keyword arguments, copied verbatim from the config.
    arch_keys = (
        "input_dim",
        "hidden_dim",
        "output_dim",
        "forecast_horizon",
        "n_layers",
        "dropout",
    )
    model = LSTMPredictor(**{key: config[key] for key in arch_keys})

    device = torch.device(config["device"])
    # NOTE(review): torch.load unpickles the downloaded file. On torch versions
    # where weights_only defaults to False, a malicious checkpoint in the remote
    # repo could execute code on load. weights_only=True suffices for a plain
    # state dict; consider it if the installed torch supports it.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
def predict(model, dataloader, config):
    """Run ``model`` over every batch of ``dataloader`` and collect the outputs.

    Gradient tracking is disabled for the whole pass. Each batch is moved to
    ``config["device"]``, and each output is moved to the CPU and converted to
    a NumPy array before the batches are joined.

    Args:
        model: Callable taking one batch tensor and returning a tensor whose
            first axis is the batch axis.
        dataloader: Iterable of batch tensors (each is moved with ``.to``).
        config: Mapping providing ``"device"``.

    Returns:
        All outputs concatenated along axis 0, in iteration order. If the
        dataloader yields no batches, an empty 1-D ``float32`` array is
        returned, since the per-sample output shape cannot be known.
    """
    device = config["device"]
    outputs = []
    with torch.no_grad():
        for batch in dataloader:
            outputs.append(model(batch.to(device)).cpu().numpy())
    if not outputs:
        # np.concatenate / np.vstack raise ValueError on an empty list.
        return np.empty((0,), dtype=np.float32)
    # Join along the batch axis. Unlike np.vstack this also handles 1-D
    # per-batch outputs and a smaller final batch; for outputs with two or
    # more dimensions the result is identical to np.vstack.
    return np.concatenate(outputs, axis=0)
def main(config_path="config.json"):
    """Run batch inference end to end using settings from a JSON config file.

    Reads ``config_path``, loads the array at ``config["data_path"]``, windows
    it with ``TimeSeriesDataset``, loads the model via ``load_model``, runs
    ``predict`` and writes the result with ``np.save``.

    Args:
        config_path: Path to the JSON config file. Defaults to
            ``"config.json"`` (relative to the working directory), the value
            that was previously hard-coded.
    """
    config = load_config(config_path)

    # Load test data. NOTE(review): assumes the saved array's layout matches
    # what LSTMPredictor expects per window (e.g. a feature axis of size
    # input_dim) - confirm against the model definition.
    raw_data = np.load(config["data_path"])
    dataset = TimeSeriesDataset(raw_data, seq_length=config["seq_length"])
    # shuffle=False keeps predictions in the same order as the input windows.
    dataloader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=False)

    model = load_model(config)
    predictions = predict(model, dataloader, config)

    # np.save appends ".npy" when the path lacks that suffix; resolve it here
    # so the message names the file that is actually written.
    output_path = str(config["output_path"])
    if not output_path.endswith(".npy"):
        output_path += ".npy"
    np.save(output_path, predictions)
    print(f"Predictions saved to {output_path}")


if __name__ == "__main__":
    main()