|
import json |
|
import torch |
|
import numpy as np |
|
from torch.utils.data import DataLoader, Dataset |
|
from lstm_predictor import LSTMPredictor |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
class TimeSeriesDataset(Dataset):
    """Sliding-window view over a sequence, used to batch inputs for inference.

    Item ``i`` is ``data[i:i + seq_length]`` converted to a ``float32`` tensor.
    Every full-length window is exposed, including the one that ends at the
    last observation.

    Args:
        data: Array-like supporting ``len()`` and slicing along its first axis
            (``main`` passes the array returned by ``np.load``).
        seq_length: Number of consecutive steps per window; must be >= 1.

    Raises:
        ValueError: If ``seq_length`` is smaller than 1.
    """

    def __init__(self, data, seq_length):
        if seq_length < 1:
            raise ValueError(f"seq_length must be >= 1, got {seq_length}")
        self.data = data
        self.seq_length = seq_length

    def __len__(self):
        # Windows start at offsets 0 .. N - seq_length inclusive, i.e. there are
        # N - seq_length + 1 of them. Clamp at 0 because len() rejects negative
        # values when the series is shorter than one window.
        return max(0, len(self.data) - self.seq_length + 1)

    def __getitem__(self, idx):
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            # Slicing past the end would silently yield a short or empty window;
            # raising IndexError surfaces bad indices and lets plain iteration
            # over the dataset terminate.
            raise IndexError(f"window index out of range: {idx}")
        window = self.data[idx:idx + self.seq_length]
        return torch.tensor(window, dtype=torch.float32)
|
|
|
|
|
def load_config(config_path):
    """Read and decode a JSON configuration file.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The decoded JSON value. The rest of this script indexes it by string
        key, so it is expected to be a JSON object.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file contents are not valid JSON.
    """
    # Pin the encoding: JSON text is UTF-8 (RFC 8259), whereas open()'s default
    # is locale-dependent (e.g. cp1252 on Windows) and can mis-decode non-ASCII.
    with open(config_path, "r", encoding="utf-8") as file:
        return json.load(file)
|
|
|
|
|
def load_model(config):
    """Download LSTMPredictor weights from the Hugging Face Hub and build the model.

    Args:
        config: Mapping read for ``repo_id`` and ``model_path`` (the checkpoint
            file inside that repo), the constructor arguments ``input_dim``,
            ``hidden_dim``, ``output_dim``, ``forecast_horizon``, ``n_layers``
            and ``dropout``, and ``device``.

    Returns:
        The ``LSTMPredictor`` with the downloaded state dict loaded, moved to
        ``config["device"]`` and switched to eval mode.
    """
    # Parse the device once, before the (possibly slow) download, so an invalid
    # value fails fast.
    device = torch.device(config["device"])

    model_file = hf_hub_download(repo_id=config["repo_id"], filename=config["model_path"])

    model = LSTMPredictor(
        input_dim=config["input_dim"],
        hidden_dim=config["hidden_dim"],
        output_dim=config["output_dim"],
        forecast_horizon=config["forecast_horizon"],
        n_layers=config["n_layers"],
        dropout=config["dropout"],
    )

    # The checkpoint is a remotely downloaded file. A plain torch.load unpickles
    # it and can execute arbitrary code; weights_only=True restricts loading to
    # tensors and primitive containers, which is all a state dict needs.
    # (Requires torch >= 1.13.)
    state_dict = torch.load(model_file, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
|
|
|
|
|
def predict(model, dataloader, config):
    """Run ``model`` over every batch and join the outputs in order.

    Args:
        model: Callable mapping an input batch tensor to an output tensor.
            Its outputs are assumed to have the batch on axis 0.
            ``load_model`` returns the model already in eval mode.
        dataloader: Iterable of input tensors, consumed in order.
        config: Mapping; only ``config["device"]`` is read.

    Returns:
        NumPy array of all batch outputs concatenated along axis 0, in
        dataloader order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    device = torch.device(config["device"])
    outputs = []
    with torch.no_grad():
        for batch in dataloader:
            result = model(batch.to(device))
            outputs.append(result.cpu().numpy())

    if not outputs:
        raise ValueError("dataloader yielded no batches; nothing to predict")

    # Join along the batch axis. np.vstack agrees with this for outputs of
    # rank >= 2, but it promotes 1-D outputs to rows. That gives
    # (n_batches, batch_size) and fails outright when the final batch is smaller.
    return np.concatenate(outputs, axis=0)
|
|
|
|
|
def main(config_path="config.json"):
    """Run batch inference end to end.

    Loads the JSON config and windows the array stored at
    ``config["data_path"]``. It then runs the Hub-downloaded model over every
    window and saves the predictions with ``np.save`` to
    ``config["output_path"]``.

    Args:
        config_path: JSON config file to read. Defaults to ``"config.json"``,
            resolved against the current working directory.
    """
    config = load_config(config_path)

    raw_data = np.load(config["data_path"])
    # NOTE(review): each item is raw_data[i:i + seq_length]. If raw_data is 1-D,
    # batches are (batch, seq_length) with no feature axis. Confirm this matches
    # what LSTMPredictor expects.
    dataset = TimeSeriesDataset(raw_data, seq_length=config["seq_length"])
    # shuffle=False keeps prediction rows aligned with window order.
    dataloader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=False)

    model = load_model(config)
    predictions = predict(model, dataloader, config)

    output_path = str(config["output_path"])
    np.save(output_path, predictions)
    # np.save appends ".npy" when the path lacks it; report the file that was
    # actually written.
    saved_path = output_path if output_path.endswith(".npy") else f"{output_path}.npy"
    print(f"Predictions saved to {saved_path}")


if __name__ == "__main__":
    main()
|
|