File size: 1,261 Bytes
8cc5633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import logging
import os

import torch
from huggingface_hub import hf_hub_download

from transformer_model.scripts.config_transformer import CHECKPOINT_DIR
from transformer_model.scripts.training.load_basis_model import \
    load_moment_model

# Configure root logging at INFO level when this module is imported, so the
# logging.info(...) status messages in this module are emitted.
# NOTE(review): this is an import-time side effect (and basicConfig is a no-op
# if the root logger already has handlers); consider moving it to the
# application entry point.
logging.basicConfig(level=logging.INFO)


def load_real_transformer_model(
    device=None,
    *,
    filename="model_final.pth",
    repo_id="dlaj/energy-forecasting-files",
):
    """Build the MOMENT base model and load the trained weights into it.

    The weights file is taken from ``CHECKPOINT_DIR/<filename>`` when it exists
    locally. Otherwise it is downloaded from the Hugging Face Hub dataset
    repository ``repo_id`` at path ``transformer_model/<filename>``.

    Args:
        device: Target device (``torch.device`` or device string). Defaults to
            CUDA when available, otherwise CPU.
        filename: Checkpoint file name, used both locally and in the Hub repo.
        repo_id: Hugging Face Hub dataset repository to download from when no
            local checkpoint file exists.

    Returns:
        Tuple ``(model, device)``: the model moved to ``device`` and switched
        to eval mode, plus the device that was used.
    """
    logger = logging.getLogger(__name__)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_moment_model()
    local_path = os.path.join(CHECKPOINT_DIR, filename)

    # isfile (not exists) so a directory with the same name is not mistaken
    # for a checkpoint.
    if os.path.isfile(local_path):
        checkpoint_path = local_path
        logger.info("Loading model from local path...")
    else:
        logger.info("Downloading model from Hugging Face Hub...")
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"transformer_model/{filename}",
            repo_type="dataset",
        )

    # The file may come from a remote repository: weights_only=True restricts
    # unpickling to tensors and plain containers, which is all a state dict
    # needs, so a tampered file cannot execute arbitrary code on load.
    state_dict = torch.load(
        checkpoint_path, map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    logger.info("Model loaded from: %s", checkpoint_path)

    return model, device