yaghi27 commited on
Commit
410ea00
·
verified ·
1 Parent(s): 3633323

Update model/run_inference.py

Browse files
Files changed (1) hide show
  1. model/run_inference.py +20 -0
model/run_inference.py CHANGED
@@ -29,6 +29,9 @@ from mmengine.registry import init_default_scope
29
  from mmdet3d.utils import register_all_modules
30
  register_all_modules(init_default_scope=False)
31
 
 
 
 
32
  @lru_cache(maxsize=2)
33
  def get_model(model_key: str):
34
  if model_key not in MODEL_REGISTRY:
@@ -40,6 +43,23 @@ def get_model(model_key: str):
40
  else: # detr3d
41
  purge_project_registrations("projects.PETR")
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  config_path, (repo_id, hf_file) = MODEL_REGISTRY[model_key]
44
  if not os.path.isfile(config_path):
45
  raise FileNotFoundError(f"Config not found: {config_path}")
 
29
  from mmdet3d.utils import register_all_modules
30
  register_all_modules(init_default_scope=False)
31
 
32
+ import importlib, sys
33
+ from pathlib import Path
34
+
35
  @lru_cache(maxsize=2)
36
  def get_model(model_key: str):
37
  if model_key not in MODEL_REGISTRY:
 
43
  else: # detr3d
44
  purge_project_registrations("projects.PETR")
45
 
46
+ # Ensure the mmdetection3d repo root (which contains `projects/`) is importable
47
+ repo_root = Path(__file__).resolve().parents[1] / "mmdetection3d"
48
+ if repo_root.is_dir() and str(repo_root) not in sys.path:
49
+ sys.path.insert(0, str(repo_root))
50
+
51
+ # Import the correct project so its registries (e.g., ResizeCropFlipImage) are registered
52
+ proj_name = "projects.PETR.petr" if "petr" in model_key.lower() else "projects.DETR3D.detr3d"
53
+ try:
54
+ importlib.import_module(proj_name)
55
+ except ModuleNotFoundError as e:
56
+ # Helpful error that tells you what path we tried
57
+ raise ModuleNotFoundError(
58
+ f"Could not import {proj_name}. Ensure the 'projects' package is on sys.path. "
59
+ f"Tried adding: {repo_root}"
60
+ ) from e
61
+
62
+
63
  config_path, (repo_id, hf_file) = MODEL_REGISTRY[model_key]
64
  if not os.path.isfile(config_path):
65
  raise FileNotFoundError(f"Config not found: {config_path}")