File size: 1,490 Bytes
a9073bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import torch

# import models
from models.LSTM import LSTM
from models.LSTNet import LSTNet
from models.Transformer import Transformer
from models.Autoformer import Autoformer
from models.Informer import Informer
from models.PatchTST import PatchTST
from models.TimesNet import TimesNet
from models.TimesFM import TimesFM

# import keyword args
from model_kwargs import *

# Experiment configuration.
# lookback is fixed to 512, while lookahead can be one of 4, 48 or 96.
# heterogeneity can be 'HET' or 'HOM'.
lookback, lookahead, heterogeneity = 512, 48, 'HET'

# Directory holding the pretrained checkpoints. It is resolved relative to this
# script (not the current working directory) so the script works no matter
# where it is launched from, just like the `models.*` imports above.
WEIGHTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')


def weights_path(model_name, lookback, lookahead, heterogeneity, weights_dir=WEIGHTS_DIR):
    """Return the checkpoint path for one model/configuration.

    The file is named ``{model_name}_L_{lookback}_T_{lookahead}_{heterogeneity}.pth``
    and lives inside *weights_dir* (defaults to the ``weights`` folder next to
    this script).
    """
    filename = f'{model_name}_L_{lookback}_T_{lookahead}_{heterogeneity}.pth'
    return os.path.join(weights_dir, filename)


if __name__ == "__main__":

    # Pair each model class with its keyword-argument factory so the two can
    # never drift out of sync (zip over two parallel lists silently truncates).
    model_specs = [
        (LSTM, lstm_kwargs),
        (LSTNet, lstnet_kwargs),
        (Transformer, transformer_kwargs),
        (Autoformer, autoformer_kwargs),
        (Informer, informer_kwargs),
        (PatchTST, patchtst_kwargs),
        (TimesNet, timesnet_kwargs),
        (TimesFM, timesfm_kwargs),
    ]

    for model_class, kw_fn in model_specs:
        name = model_class.__name__
        # Instantiate the model with the hyper-parameters for this lookback/lookahead.
        model = model_class(**kw_fn(lookback=lookback, lookahead=lookahead))
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code; only load checkpoints from a trusted source (consider
        # weights_only=True on torch >= 1.13).
        state_dict = torch.load(
            weights_path(name, lookback, lookahead, heterogeneity),
            map_location='cpu',  # load onto CPU regardless of the device used at save time
        )
        # load_state_dict reports any missing/unexpected keys in its return value.
        result = model.load_state_dict(state_dict)
        print(f"Loading weight for model {name}, lookback {lookback}, lookahead {lookahead}, heterogeneity {heterogeneity}, and the result was: {result}.")