Spaces:
Sleeping
Sleeping
import logging | |
import os | |
import torch | |
from huggingface_hub import hf_hub_download | |
from transformer_model.scripts.config_transformer import CHECKPOINT_DIR | |
from transformer_model.scripts.training.load_basis_model import \ | |
load_moment_model | |
logging.basicConfig(level=logging.INFO)

# Named module logger so status messages go through logging (not stdout via print).
logger = logging.getLogger(__name__)


def _resolve_checkpoint_path(filename, repo_id, repo_subfolder):
    """Return a local path to the checkpoint file, downloading it if needed.

    A local copy at ``CHECKPOINT_DIR/<filename>`` is preferred. Otherwise the
    file ``<repo_subfolder>/<filename>`` is fetched from the Hugging Face
    *dataset* repository ``repo_id`` with ``hf_hub_download``, whose return
    value (the path of the downloaded file) is returned.
    """
    local_path = os.path.join(CHECKPOINT_DIR, filename)
    # isfile rather than exists: a directory with the same name must not be
    # mistaken for the checkpoint.
    if os.path.isfile(local_path):
        logger.info("Loading model from local path...")
        return local_path

    logger.info("Downloading model from Hugging Face Hub...")
    # Repository paths on the Hub always use "/", independent of the host OS,
    # so they are joined manually instead of with os.path.join.
    remote_name = f"{repo_subfolder}/{filename}" if repo_subfolder else filename
    return hf_hub_download(
        repo_id=repo_id,
        filename=remote_name,
        repo_type="dataset",
    )


def load_real_transformer_model(
    device=None,
    *,
    filename="model_final.pth",
    repo_id="dlaj/energy-forecasting-files",
    repo_subfolder="transformer_model",
):
    """Create the base model and load the trained weights into it.

    Weights are taken from ``CHECKPOINT_DIR/<filename>`` when that file exists;
    otherwise they are downloaded from the Hugging Face dataset repository
    ``repo_id`` (file ``<repo_subfolder>/<filename>``).

    Args:
        device: Target device. Defaults to CUDA when available, else CPU. It is
            used as ``map_location`` for ``torch.load`` and as the target of
            ``model.to``, and is returned unchanged.
        filename: Checkpoint file name, both locally and on the Hub.
        repo_id: Hugging Face repository used when no local copy exists.
        repo_subfolder: Folder inside ``repo_id`` holding the checkpoint; pass
            ``""`` or ``None`` if the file sits at the repository root.

    Returns:
        Tuple ``(model, device)``: the model moved to ``device`` and switched
        to eval mode, and the device that was used.

    Errors from ``hf_hub_download``, ``torch.load`` or ``load_state_dict``
    (e.g. network failure, missing file, mismatched keys) propagate unchanged.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resolve the checkpoint first so a missing/unreachable file fails before
    # the base model is constructed.
    checkpoint_path = _resolve_checkpoint_path(filename, repo_id, repo_subfolder)

    model = load_moment_model()
    # NOTE(review): torch.load unpickles the file, and the checkpoint may come
    # from the Hugging Face Hub (external data). If it is a plain state_dict,
    # consider weights_only=True (the default from PyTorch 2.6 on) so arbitrary
    # pickled objects are refused.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Eval mode for inference (affects layers such as dropout / batch norm).
    model.eval()

    logger.info("Model loaded from: %s", checkpoint_path)
    return model, device