# dlaj — "Deploy from GitHub" (commit 8cc5633)
import logging
import os
import torch
from huggingface_hub import hf_hub_download
from transformer_model.scripts.config_transformer import CHECKPOINT_DIR
from transformer_model.scripts.training.load_basis_model import \
load_moment_model
# Give the root logger a default handler at INFO level so the model-loading
# messages below are emitted. NOTE(review): configuring logging at import time
# overrides the host application's choice; consider moving this to the entry script.
logging.basicConfig(level="INFO")
def _resolve_checkpoint_path(filename, repo_id, checkpoint_dir=None):
    """Return a local path to the checkpoint, downloading it from the Hub if needed.

    Args:
        filename: Name of the checkpoint file.
        repo_id: Hugging Face dataset repository used when no local copy exists.
        checkpoint_dir: Directory searched for a local copy; ``None`` means
            ``CHECKPOINT_DIR`` from the project config.

    Returns:
        str: Path of a checkpoint file on local disk.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    local_path = os.path.join(checkpoint_dir, filename)

    # isfile (not exists): a directory with this name must not be passed to torch.load.
    if os.path.isfile(local_path):
        logging.info("Loading model from local path: %s", local_path)
        return local_path

    logging.info("Downloading model from Hugging Face Hub (%s)...", repo_id)
    return hf_hub_download(
        repo_id=repo_id,
        # Path of the checkpoint inside the repository.
        filename=f"transformer_model/{filename}",
        repo_type="dataset",
    )


def load_real_transformer_model(
    device=None,
    *,
    filename="model_final.pth",
    repo_id="dlaj/energy-forecasting-files",
    checkpoint_dir=None,
):
    """Build the model, load trained weights into it and prepare it for inference.

    The model architecture comes from ``load_moment_model()``. The weights are
    read from the local checkpoint directory when the file exists there;
    otherwise they are downloaded from the Hugging Face Hub.

    Args:
        device: Target device. Defaults to CUDA when available, else CPU.
        filename: Checkpoint file name (keyword-only).
        repo_id: Hugging Face dataset repository to download from when no
            local copy exists (keyword-only).
        checkpoint_dir: Directory searched for a local copy; ``None`` means
            ``CHECKPOINT_DIR`` from the project config (keyword-only).

    Returns:
        tuple: ``(model, device)``. The model has its weights loaded, sits on
        ``device`` and is in eval mode; ``device`` is the device that was used.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_moment_model()
    checkpoint_path = _resolve_checkpoint_path(filename, repo_id, checkpoint_dir)

    # NOTE(review): torch.load unpickles the file, which may come from the Hub.
    # On torch < 2.6 the default weights_only=False allows arbitrary code
    # execution from a malicious checkpoint; consider weights_only=True
    # (available since torch 1.13) once the deployed torch version is confirmed.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    logging.info("Model loaded from: %s", checkpoint_path)
    return model, device