wuhp commited on
Commit
7411ec8
·
verified ·
1 Parent(s): bbfec58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -1
app.py CHANGED
@@ -458,6 +458,45 @@ def _install_supervisely_logger_shim():
458
  """))
459
  return str(root)
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
462
  url = CKPT_URLS.get(model_key)
463
  if not url:
@@ -571,6 +610,9 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
571
 
572
  # Disable SyncBN for single GPU/CPU runs
573
  cfg["sync_bn"] = False
 
 
 
574
 
575
  # Remove COCO dataset include so it can't override our dataset paths later
576
  if "__include__" in cfg and isinstance(cfg["__include__"], list):
@@ -825,6 +867,10 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
825
  q = Queue()
826
  def run_train():
827
  try:
 
 
 
 
828
  env = os.environ.copy()
829
  # Make sure repo code can be imported
830
  env["PYTHONPATH"] = os.pathsep.join(filter(None, [
@@ -837,7 +883,7 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
837
  # Provide a secondary hint for some config loaders
838
  env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
839
 
840
- proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
841
  stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
842
  bufsize=1, text=True, env=env)
843
  for line in proc.stdout:
 
458
  """))
459
  return str(root)
460
 
461
+ # ---- NEW: robust fallback for cfg['_pymodule'] via sitecustomize -------------
462
+ def _install_workspace_env_fallback(cwd_for_train: str, module_default: str = "rtdetrv2_pytorch.src"):
463
+ """
464
+ Creates a sitecustomize.py in the training cwd that monkeypatches
465
+ rtdetrv2_pytorch.src.core.workspace.create to gracefully handle missing
466
+ cfg['_pymodule'] by falling back to $RTDETR_PYMODULE or module_default.
467
+ """
468
+ sc_path = os.path.join(cwd_for_train, "sitecustomize.py")
469
+ code = textwrap.dedent(f"""
470
+ import os, importlib
471
+ try:
472
+ mod_path = os.environ.get("RTDETR_PYMODULE", "{module_default}")
473
+ ws_mod = importlib.import_module("rtdetrv2_pytorch.src.core.workspace")
474
+ _orig_create = ws_mod.create
475
+ def _safe_create(name, cfg, *args, **kwargs):
476
+ pm = None
477
+ try:
478
+ pm = cfg.get("_pymodule", None)
479
+ except Exception:
480
+ pm = None
481
+ if not pm:
482
+ pm = os.environ.get("RTDETR_PYMODULE", "{module_default}")
483
+ try:
484
+ importlib.import_module(pm)
485
+ except Exception:
486
+ pass
487
+ try:
488
+ cfg["_pymodule"] = pm
489
+ except Exception:
490
+ pass
491
+ return _orig_create(name, cfg, *args, **kwargs)
492
+ ws_mod.create = _safe_create
493
+ except Exception:
494
+ pass
495
+ """)
496
+ with open(sc_path, "w", encoding="utf-8") as f:
497
+ f.write(code)
498
+ return sc_path
499
+
500
  def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
501
  url = CKPT_URLS.get(model_key)
502
  if not url:
 
610
 
611
  # Disable SyncBN for single GPU/CPU runs
612
  cfg["sync_bn"] = False
613
+ # Guardrails for single-process runs
614
+ cfg.setdefault("device", "")
615
+ cfg["find_unused_parameters"] = False
616
 
617
  # Remove COCO dataset include so it can't override our dataset paths later
618
  if "__include__" in cfg and isinstance(cfg["__include__"], list):
 
867
  q = Queue()
868
  def run_train():
869
  try:
870
+ # Ensure our fallback hook is available in the train process (CWD on sys.path)
871
+ train_cwd = os.path.dirname(train_script)
872
+ _install_workspace_env_fallback(train_cwd)
873
+
874
  env = os.environ.copy()
875
  # Make sure repo code can be imported
876
  env["PYTHONPATH"] = os.pathsep.join(filter(None, [
 
883
  # Provide a secondary hint for some config loaders
884
  env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
885
 
886
+ proc = subprocess.Popen(cmd, cwd=train_cwd,
887
  stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
888
  bufsize=1, text=True, env=env)
889
  for line in proc.stdout: