wuhp committed (verified)
Commit bbfec58 · Parent: 04a0a1c

Update app.py

Files changed (1):
  1. app.py (+14 / -40)
app.py CHANGED
@@ -21,10 +21,9 @@ REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
 REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
 PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch")  # Supervisely keeps PyTorch impl here
 
-# Core deps + your requested packages; pinned as lower-bounds to avoid downgrades (local runs only)
+# Core deps; Ultralytics removed per request
 COMMON_REQUIREMENTS = [
     "gradio>=4.36.1",
-    "ultralytics>=8.2.0",
     "roboflow>=1.1.28",
     "requests>=2.31.0",
     "huggingface_hub>=0.22.0",
@@ -427,38 +426,28 @@ def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None
     return None
 
 def _set_first_existing_key_deep(cfg: dict, keys: list, value):
-    """
-    Try to set one of `keys` at top-level, under 'model', or under 'solver'.
-    """
     for scope in [cfg, cfg.get("model", {}), cfg.get("solver", {})]:
         if isinstance(scope, dict):
             for k in keys:
                 if k in scope:
                     scope[k] = value
                     return True
-    # If nowhere found, set on model
     if "model" not in cfg or not isinstance(cfg["model"], dict):
         cfg["model"] = {}
     cfg["model"][keys[0]] = value
     return True
 
 def _install_supervisely_logger_shim():
-    """
-    Create a package shim so 'from supervisely.nn.training import train_logger' works.
-    """
     root = pathlib.Path(tempfile.gettempdir()) / "sly_shim_pkg"
     pkg_training = root / "supervisely" / "nn" / "training"
     pkg_training.mkdir(parents=True, exist_ok=True)
 
-    # Make each level a package
     for p in [root / "supervisely", root / "supervisely" / "nn", pkg_training]:
         init_file = p / "__init__.py"
         if not init_file.exists():
             init_file.write_text("")
 
-    # Expose train_logger from the package's __init__
     (pkg_training / "__init__.py").write_text(textwrap.dedent("""
-        # Minimal shim for backward-compat with older RT-DETRv2 training code.
         class _TrainLogger:
             def __init__(self): pass
             def reset(self): pass
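
The shim above fabricates a `supervisely.nn.training` package under the temp dir so the legacy `from supervisely.nn.training import train_logger` import keeps resolving even when Supervisely is not installed. A minimal self-contained sketch of the same trick (the demo directory name and the no-op logger body are illustrative, not the app's exact shim):

```python
import pathlib
import sys
import tempfile
import textwrap

# Build a fake package tree: supervisely/nn/training, each level a package.
root = pathlib.Path(tempfile.gettempdir()) / "sly_shim_demo"
pkg = root / "supervisely" / "nn" / "training"
pkg.mkdir(parents=True, exist_ok=True)
for p in (root / "supervisely", root / "supervisely" / "nn", pkg):
    (p / "__init__.py").touch()

# The shim object answers any logger hook with a no-op callable.
(pkg / "__init__.py").write_text(textwrap.dedent("""
    class _TrainLogger:
        def __getattr__(self, name):
            return lambda *args, **kwargs: None

    train_logger = _TrainLogger()
"""))

sys.path.insert(0, str(root))  # same effect as prepending to PYTHONPATH
from supervisely.nn.training import train_logger
train_logger.reset()  # resolves and does nothing
```

First match on `sys.path` wins, which is why the training launcher later prepends the shim root to `PYTHONPATH`.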
@@ -498,10 +487,6 @@ def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
 # --- include absolutizer ------------------------------------------------------
 def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE", "BASE_YAML",
                                                          "includes", "include", "BASES", "__include__")):
-    """
-    Walk dict/list; for known include keys or strings that look like ../*.yml/.yaml,
-    make them absolute against base_dir.
-    """
     def _absify(s: str) -> str:
         if os.path.isabs(s):
             return s
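
The absolutizer's core move, sketched under the assumption that include values are strings relative to the template directory (the helper name here is illustrative):

```python
import os

def absify(value: str, base_dir: str) -> str:
    # Leave absolute paths alone; resolve relative ones against the template dir.
    if os.path.isabs(value):
        return value
    return os.path.normpath(os.path.join(base_dir, value))

cfg = {"__include__": ["../base.yml"]}
base_dir = "/repo/configs/rtdetrv2"
cfg["__include__"] = [absify(v, base_dir) for v in cfg["__include__"]]
assert cfg["__include__"] == ["/repo/configs/base.yml"]
```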
@@ -531,9 +516,6 @@ def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE
 
 # --- NEW: safe model field setters --------------------------------------------
 def _set_num_classes_safely(cfg: dict, n: int):
-    """
-    Set class count without breaking templates that use `model: "RTDETR"` indirection.
-    """
     def set_num_classes(node):
         if not isinstance(node, dict):
             return False
@@ -558,12 +540,9 @@ def _set_num_classes_safely(cfg: dict, n: int):
         block["num_classes"] = int(n)
         return
 
-    cfg["num_classes"] = int(n)  # last resort
+    cfg["num_classes"] = int(n)
 
 def _maybe_set_model_field(cfg: dict, key: str, value):
-    """
-    Place fields like 'pretrain' under the proper model dict, respecting string indirection.
-    """
     m = cfg.get("model", None)
     if isinstance(m, dict):
         m[key] = value
@@ -571,7 +550,7 @@ def _maybe_set_model_field(cfg: dict, key: str, value):
     if isinstance(m, str) and isinstance(cfg.get(m), dict):
         cfg[m][key] = value
         return
-    cfg[key] = value  # fallback
+    cfg[key] = value
 
 # --- CRITICAL: dataset override + include cleanup + sync_bn off ---------------
 def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
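
Both setters guard against templates where `model:` is a string naming the real model block rather than a dict. A runnable sketch of that indirection (the template shape is assumed):

```python
cfg = {
    "model": "RTDETR",                 # string points at the real block
    "RTDETR": {"backbone": "PResNet"},
}

def maybe_set_model_field(cfg: dict, key: str, value):
    m = cfg.get("model")
    if isinstance(m, dict):            # normal case: inline model dict
        m[key] = value
        return
    if isinstance(m, str) and isinstance(cfg.get(m), dict):
        cfg[m][key] = value            # follow the string indirection
        return
    cfg[key] = value                   # fallback: top level

maybe_set_model_field(cfg, "pretrain", "/abs/path/ckpt.pth")
assert cfg["RTDETR"]["pretrain"] == "/abs/path/ckpt.pth"
```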
@@ -586,6 +565,10 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
     cfg = yaml.safe_load(f)
     _absify_any_paths_deep(cfg, template_dir)
 
+    # Ensure the runtime knows which Python module hosts the builders
+    cfg["task"] = cfg.get("task", "detection")
+    cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src")  # hint for the config loader
+
     # Disable SyncBN for single GPU/CPU runs
     cfg["sync_bn"] = False
 
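
Both new hints are written non-destructively: `dict.get` with a fallback (or `setdefault`) only fills a key the template left unset, as this small check shows:

```python
cfg = {"task": "detection_custom"}
cfg["task"] = cfg.get("task", "detection")           # existing value wins
cfg.setdefault("_pymodule", "rtdetrv2_pytorch.src")  # same effect in one call
assert cfg["task"] == "detection_custom"
assert cfg["_pymodule"] == "rtdetrv2_pytorch.src"
```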
@@ -607,7 +590,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
     }
 
-    # Ensure/patch dataloaders to point to our dataset
     def ensure_and_patch_dl(dl_key, img_key, json_key, default_shuffle):
         block = cfg.get(dl_key)
         if not isinstance(block, dict):
@@ -634,7 +616,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
             }
             cfg[dl_key] = block
 
-        # Patch existing block
         ds = block.get("dataset", {})
         if isinstance(ds, dict):
             ds["img_folder"] = paths[img_key]
@@ -652,13 +633,9 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
 
     ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
     ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
-    # Optional test loader
-    # ensure_and_patch_dl("test_dataloader", "test_img", "test_json", default_shuffle=False)
 
-    # num classes (handles model: "RTDETR")
     _set_num_classes_safely(cfg, int(class_count))
 
-    # epochs / imgsz
     applied_epoch = False
     for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
         if key in cfg:
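
The alias loop exists because RT-DETR-family templates spell the epoch count differently across versions ("epoches" is the historical spelling). A sketch of the alias-tolerant setter:

```python
def set_epochs(cfg: dict, epochs: int) -> bool:
    # Write whichever alias the template already uses; add the historical
    # spelling only when none is present.
    for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
        if key in cfg:
            cfg[key] = int(epochs)
            return True
    cfg["epoches"] = int(epochs)
    return False

cfg = {"max_epoch": 72}
set_epochs(cfg, 20)
assert cfg == {"max_epoch": 20}
```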
@@ -675,7 +652,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         cfg["epoches"] = int(epochs)
     cfg["input_size"] = int(imgsz)
 
-    # lr / optimizer / batch
     if "solver" not in cfg or not isinstance(cfg["solver"], dict):
         cfg["solver"] = {}
     sol = cfg["solver"]
@@ -689,24 +665,20 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
     if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict):
         sol["batch_size"] = int(batch)
 
-    # output dir
     if "output_dir" in cfg:
         cfg["output_dir"] = paths["out_dir"]
     else:
         sol["output_dir"] = paths["out_dir"]
 
-    # pretrained weights in the right model block
     if pretrained_path:
         p = os.path.abspath(pretrained_path)
         _maybe_set_model_field(cfg, "pretrain", p)
         _maybe_set_model_field(cfg, "pretrained", p)
 
-    # Save near the template so internal relative references still make sense
     cfg_out_dir = os.path.join(template_dir, "generated")
     os.makedirs(cfg_out_dir, exist_ok=True)
     out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
 
-    # Force block style for lists (no inline [a, b, c])
     class _NoFlowDumper(yaml.SafeDumper): ...
     def _repr_list_block(dumper, data):
         return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)
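
`_NoFlowDumper` forces block-style sequences so no list is emitted inline. The same pattern, runnable on its own (depending on the dump defaults in use, short lists would otherwise come out as `[640, 640]`):

```python
import yaml

class NoFlowDumper(yaml.SafeDumper):
    pass

def repr_list_block(dumper, data):
    # Represent every list as a block sequence, one item per line.
    return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=False)

NoFlowDumper.add_representer(list, repr_list_block)

print(yaml.dump({"input_size": [640, 640]}, Dumper=NoFlowDumper))
# input_size:
# - 640
# - 640
```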
@@ -832,7 +804,6 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     out_dir = os.path.abspath(os.path.join("runs", "train", run_name))
     os.makedirs(out_dir, exist_ok=True)
 
-    # Download matching COCO checkpoint for warm-start
     pretrained_path = _ensure_checkpoint(model_key, out_dir)
 
     cfg_path = patch_base_config(
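
`_ensure_checkpoint` downloads the COCO checkpoint matching `model_key` into `out_dir` and returns its path for warm-starting. A sketch of that contract with a placeholder URL table (the app's real mapping is not shown in this diff):

```python
import os
import requests

CHECKPOINT_URLS = {"rtdetrv2_s": "https://example.com/rtdetrv2_s_coco.pth"}  # placeholder

def ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
    url = CHECKPOINT_URLS.get(model_key)
    if not url:
        return None                       # no match: train from scratch
    dst = os.path.join(out_dir, os.path.basename(url))
    if not os.path.exists(dst):           # cache across runs
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(dst, "wb") as f:
                for chunk in r.iter_content(1 << 20):
                    f.write(chunk)
    return dst
```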
@@ -855,13 +826,17 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     def run_train():
         try:
             env = os.environ.copy()
+            # Make sure repo code can be imported
             env["PYTHONPATH"] = os.pathsep.join(filter(None, [
                 PY_IMPL_DIR, REPO_DIR, env.get("PYTHONPATH", "")
             ]))
-            # put our shim at the very front so the import always resolves
+            # Put our shim first so supervisely import never breaks
            shim_root = _install_supervisely_logger_shim()
            env["PYTHONPATH"] = os.pathsep.join([shim_root, env["PYTHONPATH"]])
            env.setdefault("WANDB_DISABLED", "true")
+            # Provide a secondary hint for some config loaders
+            env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
+
             proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
                                     stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                     bufsize=1, text=True, env=env)
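
The launcher builds the child environment by joining the shim root and the repo dirs with `os.pathsep`, shim first so its packages shadow any installed `supervisely`. A standalone sketch (paths illustrative):

```python
import os
import subprocess

def build_env(shim_root: str, *repo_dirs: str) -> dict:
    env = os.environ.copy()
    env["PYTHONPATH"] = os.pathsep.join(
        filter(None, [shim_root, *repo_dirs, env.get("PYTHONPATH", "")])
    )
    env.setdefault("WANDB_DISABLED", "true")  # no accidental W&B prompts
    return env

env = build_env("/tmp/sly_shim_pkg", "/repo/rtdetrv2_pytorch", "/repo")
subprocess.run(["python", "-c", "import sys; print(sys.path[:3])"], env=env)
```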
@@ -882,13 +857,13 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
             if line.startswith("__EXITCODE__"):
                 code = int(line.split(":", 1)[1])
                 if code != 0:
-                    head = "\n".join(first_lines[:60])
+                    head = "\n".join(first_lines[-200:])
                     raise gr.Error(f"Training exited with code {code}.\nLast output:\n{head or 'No logs captured.'}")
                 break
             if line.startswith("__ERROR__"):
                 raise gr.Error(f"Training failed: {line.split(':', 1)[1]}")
 
-            if len(first_lines) < 120:
+            if len(first_lines) < 2000:
                 first_lines.append(line)
             log_tail.append(line)
             log_tail = log_tail[-40:]
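
The capture now keeps up to 2000 head lines and reports the last 200 of them when training fails, alongside a 40-line rolling tail. A `collections.deque` gives the same tail behavior without re-slicing a list on every line; a sketch:

```python
from collections import deque

first_lines: list[str] = []
log_tail: deque[str] = deque(maxlen=40)   # rolling tail for the UI

for line in (f"line {i}" for i in range(10_000)):  # stand-in for proc stdout
    if len(first_lines) < 2000:                    # cap matches the commit
        first_lines.append(line)
    log_tail.append(line)

head = "\n".join(first_lines[-200:])               # what the error dialog shows
assert len(log_tail) == 40
```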
@@ -902,7 +877,6 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
                 pass
             progress(min(max(last_epoch / max(1, total_epochs), 0.0), 1.0), desc=f"Epoch {last_epoch}/{total_epochs}")
 
-            # Throttle plotting; close figs after yield to avoid leaks
             line_no += 1
             fig1 = fig2 = None
             if line_no % 80 == 0:
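
The modulo guard redraws figures only every 80th log line, and each figure is closed after it is handed to the UI so Matplotlib's figure registry does not grow over a long run. A sketch of the pattern (assumes the headless `Agg` backend):

```python
import matplotlib
matplotlib.use("Agg")                    # headless backend for a server process
import matplotlib.pyplot as plt

line_no = 0
for line in (f"epoch {i}" for i in range(400)):  # stand-in for the log stream
    line_no += 1
    if line_no % 80 != 0:
        continue                          # throttle: redraw every 80th line
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])               # placeholder training curve
    # ... yield fig to the UI here ...
    plt.close(fig)                        # free it once rendered
```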
 