Spaces:
Runtime error
Runtime error
Update model/run_inference.py
Browse files- model/run_inference.py +20 -0
model/run_inference.py
CHANGED
|
@@ -29,6 +29,9 @@ from mmengine.registry import init_default_scope
|
|
| 29 |
from mmdet3d.utils import register_all_modules
|
| 30 |
register_all_modules(init_default_scope=False)
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
@lru_cache(maxsize=2)
|
| 33 |
def get_model(model_key: str):
|
| 34 |
if model_key not in MODEL_REGISTRY:
|
|
@@ -40,6 +43,23 @@ def get_model(model_key: str):
|
|
| 40 |
else: # detr3d
|
| 41 |
purge_project_registrations("projects.PETR")
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
config_path, (repo_id, hf_file) = MODEL_REGISTRY[model_key]
|
| 44 |
if not os.path.isfile(config_path):
|
| 45 |
raise FileNotFoundError(f"Config not found: {config_path}")
|
|
|
|
| 29 |
from mmdet3d.utils import register_all_modules
|
| 30 |
register_all_modules(init_default_scope=False)
|
| 31 |
|
| 32 |
+
import importlib, sys
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
|
| 35 |
@lru_cache(maxsize=2)
|
| 36 |
def get_model(model_key: str):
|
| 37 |
if model_key not in MODEL_REGISTRY:
|
|
|
|
| 43 |
else: # detr3d
|
| 44 |
purge_project_registrations("projects.PETR")
|
| 45 |
|
| 46 |
+
# Ensure the mmdetection3d repo root (which contains `projects/`) is importable
|
| 47 |
+
repo_root = Path(__file__).resolve().parents[1] / "mmdetection3d"
|
| 48 |
+
if repo_root.is_dir() and str(repo_root) not in sys.path:
|
| 49 |
+
sys.path.insert(0, str(repo_root))
|
| 50 |
+
|
| 51 |
+
# Import the correct project so its registries (e.g., ResizeCropFlipImage) are registered
|
| 52 |
+
proj_name = "projects.PETR.petr" if "petr" in model_key.lower() else "projects.DETR3D.detr3d"
|
| 53 |
+
try:
|
| 54 |
+
importlib.import_module(proj_name)
|
| 55 |
+
except ModuleNotFoundError as e:
|
| 56 |
+
# Helpful error that tells you what path we tried
|
| 57 |
+
raise ModuleNotFoundError(
|
| 58 |
+
f"Could not import {proj_name}. Ensure the 'projects' package is on sys.path. "
|
| 59 |
+
f"Tried adding: {repo_root}"
|
| 60 |
+
) from e
|
| 61 |
+
|
| 62 |
+
|
| 63 |
config_path, (repo_id, hf_file) = MODEL_REGISTRY[model_key]
|
| 64 |
if not os.path.isfile(config_path):
|
| 65 |
raise FileNotFoundError(f"Config not found: {config_path}")
|