# risk_biased_prediction / import_model_from_huggingface.py
# Last commit 5769ee4 by jmercat: "Removed history to avoid any unverified
# information being released".
import torch
from huggingface_hub import cached_download, hf_hub_download, hf_hub_url
from mmcv import Config

from risk_biased.utils.load_model import get_predictor
from risk_biased.utils.torch_utils import load_weights
from risk_biased.utils.waymo_dataloader import WaymoDataloaders

# Download the trained risk-biased predictor from the Hugging Face Hub and
# restore its weights.
#
# Module-level names (kept for callers that import them from this module):
#   config_file -- local path of the downloaded training config file
#   ckpt        -- checkpoint object returned by ``torch.load`` (mapped to CPU)
#   cfg         -- mmcv ``Config`` parsed from ``config_file``
#   predictor   -- predictor built from ``cfg`` with the checkpoint weights loaded

REPO_ID = "jmercat/risk_biased_model"

# ``hf_hub_download`` is the replacement for the deprecated ``cached_download``
# helper: it stores the file in the local Hub cache and returns its path.
config_file = hf_hub_download(repo_id=REPO_ID, filename="learning_config.py")
ckpt_file = hf_hub_download(repo_id=REPO_ID, filename="last.ckpt")

# NOTE(security): torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only point REPO_ID at a repository you trust.
# map_location="cpu" remaps stored tensors to CPU so loading works without a GPU.
ckpt = torch.load(ckpt_file, map_location="cpu")

cfg = Config.fromfile(config_file)
# NOTE(review): the Waymo dataloader's unnormalize_trajectory is handed to the
# predictor, presumably to map its outputs back to un-normalized coordinates
# -- confirm against get_predictor.
predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
predictor = load_weights(predictor, ckpt)
print("Model loaded")