File size: 725 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from huggingface_hub import hf_hub_url, cached_download
from mmcv import Config
import torch

from risk_biased.utils.load_model import get_predictor
from risk_biased.utils.torch_utils import load_weights
from risk_biased.utils.waymo_dataloader import WaymoDataloaders


# Hugging Face Hub repository that hosts the pretrained risk-biased model
# files (training config and checkpoint).
REPO_ID = "jmercat/risk_biased_model"

# Download (or reuse the cached copy of) the training configuration file.
# `force_filename` pins the file's name in the local cache; it now matches the
# remote filename (it was previously misspelled "learing_config.py").
config_file = cached_download(
    hf_hub_url(REPO_ID, filename="learning_config.py"),
    force_filename="learning_config.py",
)

# Download (or reuse the cached copy of) the checkpoint and load it onto CPU,
# so no GPU is required just to read the weights.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# checkpoints from a trusted repository. Newer torch releases default to
# weights_only=True, which may reject a full training checkpoint; confirm
# against the torch version in use.
ckpt = torch.load(
    cached_download(
        hf_hub_url(REPO_ID, filename="last.ckpt"),
        force_filename="last.ckpt",
    ),
    map_location="cpu",
)

# Parse the downloaded config, build the predictor from it, then load the
# checkpoint weights into that predictor.
cfg = Config.fromfile(config_file)
# presumably unnormalize_trajectory maps model outputs back to the dataset's
# coordinate frame; confirm in risk_biased.utils.waymo_dataloader.
predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
predictor = load_weights(predictor, ckpt)

print("Model loaded")