Spaces:
Runtime error
Runtime error
Update model/run_inference.py
Browse files- model/run_inference.py +8 -8
model/run_inference.py
CHANGED
|
@@ -50,17 +50,17 @@ CAM_ORDER = [
|
|
| 50 |
DUMMY_4x4 = np.eye(4, dtype=np.float32)
|
| 51 |
|
| 52 |
|
| 53 |
-
def _download_ckpt_from_hf(repo_id: str, filename: str, cache_dir: str =
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
| 57 |
ckpt_path = hf_hub_download(
|
| 58 |
repo_id=repo_id,
|
| 59 |
filename=filename,
|
| 60 |
-
revision="main",
|
| 61 |
-
cache_dir=cache_dir,
|
| 62 |
-
|
| 63 |
-
token=os.environ.get("HF_TOKEN", None),
|
| 64 |
)
|
| 65 |
if not os.path.isfile(ckpt_path):
|
| 66 |
raise FileNotFoundError(f"Failed to download checkpoint: {repo_id}/{filename}")
|
|
|
|
| 50 |
DUMMY_4x4 = np.eye(4, dtype=np.float32)
|
| 51 |
|
| 52 |
|
| 53 |
+
def _download_ckpt_from_hf(repo_id: str, filename: str, cache_dir: str = None) -> str:
|
| 54 |
+
# default to a writable cache location in Spaces/containers
|
| 55 |
+
cache_dir = cache_dir or os.environ.get("HF_HUB_CACHE") or "/tmp/hf_cache"
|
| 56 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 57 |
+
|
| 58 |
ckpt_path = hf_hub_download(
|
| 59 |
repo_id=repo_id,
|
| 60 |
filename=filename,
|
| 61 |
+
revision="main",
|
| 62 |
+
cache_dir=cache_dir,
|
| 63 |
+
token=os.environ.get("HF_TOKEN"),
|
|
|
|
| 64 |
)
|
| 65 |
if not os.path.isfile(ckpt_path):
|
| 66 |
raise FileNotFoundError(f"Failed to download checkpoint: {repo_id}/{filename}")
|