yaghi27 commited on
Commit
e04f1e3
·
verified ·
1 Parent(s): cae9265

Update model/run_inference.py

Browse files
Files changed (1) hide show
  1. model/run_inference.py +8 -8
model/run_inference.py CHANGED
@@ -50,17 +50,17 @@ CAM_ORDER = [
50
  DUMMY_4x4 = np.eye(4, dtype=np.float32)
51
 
52
 
53
- def _download_ckpt_from_hf(repo_id: str, filename: str, cache_dir: str = "./.hf_cache") -> str:
54
- """
55
- Download a checkpoint from Hugging Face Hub. Uses HF_TOKEN env if set (for private repos).
56
- """
 
57
  ckpt_path = hf_hub_download(
58
  repo_id=repo_id,
59
  filename=filename,
60
- revision="main", # change to a tag/commit if you want pinning
61
- cache_dir=cache_dir, # keeps the file across runs
62
- local_dir=None, # rely on cache_dir
63
- token=os.environ.get("HF_TOKEN", None),
64
  )
65
  if not os.path.isfile(ckpt_path):
66
  raise FileNotFoundError(f"Failed to download checkpoint: {repo_id}/{filename}")
 
50
  DUMMY_4x4 = np.eye(4, dtype=np.float32)
51
 
52
 
53
+ def _download_ckpt_from_hf(repo_id: str, filename: str, cache_dir: str = None) -> str:
54
+ # default to a writable cache location in Spaces/containers
55
+ cache_dir = cache_dir or os.environ.get("HF_HUB_CACHE") or "/tmp/hf_cache"
56
+ os.makedirs(cache_dir, exist_ok=True)
57
+
58
  ckpt_path = hf_hub_download(
59
  repo_id=repo_id,
60
  filename=filename,
61
+ revision="main",
62
+ cache_dir=cache_dir,
63
+ token=os.environ.get("HF_TOKEN"),
 
64
  )
65
  if not os.path.isfile(ckpt_path):
66
  raise FileNotFoundError(f"Failed to download checkpoint: {repo_id}/{filename}")