Update app.py
app.py
CHANGED
@@ -21,10 +21,9 @@ REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
 REPO_DIR = os.path.join(os.getcwd(), "third_party", "RT-DETRv2")
 PY_IMPL_DIR = os.path.join(REPO_DIR, "rtdetrv2_pytorch")  # Supervisely keeps PyTorch impl here
 
-# Core deps
+# Core deps — Ultralytics removed per request
 COMMON_REQUIREMENTS = [
     "gradio>=4.36.1",
-    "ultralytics>=8.2.0",
     "roboflow>=1.1.28",
     "requests>=2.31.0",
     "huggingface_hub>=0.22.0",
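Note: this hunk only edits the dependency pins; how COMMON_REQUIREMENTS is consumed is outside the diff. A typical Spaces-style consumer (an assumption, not shown in this commit) installs any missing pins at startup:

import subprocess, sys

def ensure_requirements(reqs):
    # Hypothetical helper: pip resolves the ">=" specifiers itself; a failed
    # install raises CalledProcessError here instead of failing later at import.
    for spec in reqs:
        subprocess.check_call([sys.executable, "-m", "pip", "install", spec])

ensure_requirements(["gradio>=4.36.1", "roboflow>=1.1.28"])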
@@ -427,38 +426,28 @@ def _set_first_existing_key(d: dict, keys: list, value, fallback_key: str | None
     return None
 
 def _set_first_existing_key_deep(cfg: dict, keys: list, value):
-    """
-    Try to set one of `keys` at top-level, under 'model', or under 'solver'.
-    """
     for scope in [cfg, cfg.get("model", {}), cfg.get("solver", {})]:
         if isinstance(scope, dict):
             for k in keys:
                 if k in scope:
                     scope[k] = value
                     return True
-    # If nowhere found, set on model
    if "model" not in cfg or not isinstance(cfg["model"], dict):
        cfg["model"] = {}
    cfg["model"][keys[0]] = value
    return True
 
 def _install_supervisely_logger_shim():
-    """
-    Create a package shim so 'from supervisely.nn.training import train_logger' works.
-    """
     root = pathlib.Path(tempfile.gettempdir()) / "sly_shim_pkg"
     pkg_training = root / "supervisely" / "nn" / "training"
     pkg_training.mkdir(parents=True, exist_ok=True)
 
-    # Make each level a package
     for p in [root / "supervisely", root / "supervisely" / "nn", pkg_training]:
         init_file = p / "__init__.py"
         if not init_file.exists():
             init_file.write_text("")
 
-    # Expose train_logger from the package's __init__
     (pkg_training / "__init__.py").write_text(textwrap.dedent("""
-        # Minimal shim for backward-compat with older RT-DETRv2 training code.
         class _TrainLogger:
             def __init__(self): pass
             def reset(self): pass
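Note: the shim mechanism is easy to miss in diff form. A self-contained sketch of the same technique follows; the trailing `train_logger = _TrainLogger()` line is an assumption, since the tail of the written module falls outside this hunk:

import pathlib, sys, tempfile, textwrap

root = pathlib.Path(tempfile.mkdtemp()) / "shim_demo"
pkg = root / "supervisely" / "nn" / "training"
pkg.mkdir(parents=True)
# Every directory level needs an __init__.py to be importable as a package.
for p in (root / "supervisely", root / "supervisely" / "nn"):
    (p / "__init__.py").write_text("")
(pkg / "__init__.py").write_text(textwrap.dedent("""
    class _TrainLogger:
        def reset(self): pass
    train_logger = _TrainLogger()
"""))

# Prepending the root mirrors what PYTHONPATH does for the training subprocess.
sys.path.insert(0, str(root))
from supervisely.nn.training import train_logger
train_logger.reset()  # no-op, but the legacy import path now resolves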
@@ -498,10 +487,6 @@ def _ensure_checkpoint(model_key: str, out_dir: str) -> str | None:
 # --- include absolutizer ------------------------------------------------------
 def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE", "BASE_YAML",
                                                          "includes", "include", "BASES", "__include__")):
-    """
-    Walk dict/list; for known include keys or strings that look like ../*.yml/.yaml,
-    make them absolute against base_dir.
-    """
     def _absify(s: str) -> str:
         if os.path.isabs(s):
             return s
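Note: for readers without the surrounding file, the helper's job is to keep relative include references valid once the generated config is written elsewhere. A simplified functional variant (the real one mutates in place and also filters on include_keys):

import os

def absify(node, base_dir):
    # Recurse through dicts/lists; anchor relative *.yml/.yaml strings at base_dir.
    if isinstance(node, dict):
        return {k: absify(v, base_dir) for k, v in node.items()}
    if isinstance(node, list):
        return [absify(v, base_dir) for v in node]
    if isinstance(node, str) and node.endswith((".yml", ".yaml")) and not os.path.isabs(node):
        return os.path.normpath(os.path.join(base_dir, node))
    return node

cfg = {"__include__": ["../base/rtdetrv2_base.yml"]}
print(absify(cfg, "/repo/configs/template"))
# {'__include__': ['/repo/configs/base/rtdetrv2_base.yml']}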
@@ -531,9 +516,6 @@ def _absify_any_paths_deep(node, base_dir, include_keys=("base", "_base_", "BASE
 
 # --- NEW: safe model field setters --------------------------------------------
 def _set_num_classes_safely(cfg: dict, n: int):
-    """
-    Set class count without breaking templates that use `model: "RTDETR"` indirection.
-    """
     def set_num_classes(node):
         if not isinstance(node, dict):
             return False
@@ -558,12 +540,9 @@ def _set_num_classes_safely(cfg: dict, n: int):
             block["num_classes"] = int(n)
             return
 
-    cfg["num_classes"] = int(n)
+    cfg["num_classes"] = int(n)
 
 def _maybe_set_model_field(cfg: dict, key: str, value):
-    """
-    Place fields like 'pretrain' under the proper model dict, respecting string indirection.
-    """
     m = cfg.get("model", None)
     if isinstance(m, dict):
         m[key] = value
@@ -571,7 +550,7 @@ def _maybe_set_model_field(cfg: dict, key: str, value):
     if isinstance(m, str) and isinstance(cfg.get(m), dict):
         cfg[m][key] = value
         return
-    cfg[key] = value
+    cfg[key] = value
 
 # --- CRITICAL: dataset override + include cleanup + sync_bn off ---------------
 def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
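Note: the `isinstance(m, str)` branch handles templates where `model:` names another top-level block instead of nesting a dict. A compact demonstration (block and field names are illustrative):

cfg = {
    "model": "RTDETR",                      # string indirection
    "RTDETR": {"backbone": "PResNet50"},    # the real model block
}

def maybe_set_model_field(cfg, key, value):   # same shape as the diff's helper
    m = cfg.get("model")
    if isinstance(m, dict):
        m[key] = value
        return
    if isinstance(m, str) and isinstance(cfg.get(m), dict):
        cfg[m][key] = value
        return
    cfg[key] = value

maybe_set_model_field(cfg, "pretrain", "/abs/ckpt.pth")
assert cfg["RTDETR"]["pretrain"] == "/abs/ckpt.pth"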
@@ -586,6 +565,10 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         cfg = yaml.safe_load(f)
     _absify_any_paths_deep(cfg, template_dir)
 
+    # Ensure the runtime knows which Python module hosts builders
+    cfg["task"] = cfg.get("task", "detection")
+    cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src")  # <= HINT for loader
+
     # Disable SyncBN for single GPU/CPU runs
     cfg["sync_bn"] = False
 
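Note: both added lines use the non-destructive `dict.get` default idiom, so a template that already sets `task` or `_pymodule` keeps its value and only missing keys fall back to the hints. For example:

cfg = {"task": "detection_custom"}
cfg["task"] = cfg.get("task", "detection")                       # untouched
cfg["_pymodule"] = cfg.get("_pymodule", "rtdetrv2_pytorch.src")  # filled in
assert cfg == {"task": "detection_custom", "_pymodule": "rtdetrv2_pytorch.src"}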
@@ -607,7 +590,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         "out_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
     }
 
-    # Ensure/patch dataloaders to point to our dataset
     def ensure_and_patch_dl(dl_key, img_key, json_key, default_shuffle):
         block = cfg.get(dl_key)
         if not isinstance(block, dict):
@@ -634,7 +616,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
             }
             cfg[dl_key] = block
 
-        # Patch existing block
         ds = block.get("dataset", {})
         if isinstance(ds, dict):
             ds["img_folder"] = paths[img_key]
@@ -652,13 +633,9 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
 
     ensure_and_patch_dl("train_dataloader", "train_img", "train_json", default_shuffle=True)
     ensure_and_patch_dl("val_dataloader", "val_img", "val_json", default_shuffle=False)
-    # Optional test loader
-    # ensure_and_patch_dl("test_dataloader", "test_img", "test_json", default_shuffle=False)
 
-    # num classes (handles model: "RTDETR")
     _set_num_classes_safely(cfg, int(class_count))
 
-    # epochs / imgsz
     applied_epoch = False
     for key in ("epoches", "max_epoch", "epochs", "num_epochs"):
         if key in cfg:
@@ -675,7 +652,6 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
         cfg["epoches"] = int(epochs)
     cfg["input_size"] = int(imgsz)
 
-    # lr / optimizer / batch
     if "solver" not in cfg or not isinstance(cfg["solver"], dict):
         cfg["solver"] = {}
     sol = cfg["solver"]
@@ -689,24 +665,20 @@ def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
     if "train_dataloader" not in cfg or not isinstance(cfg["train_dataloader"], dict):
         sol["batch_size"] = int(batch)
 
-    # output dir
     if "output_dir" in cfg:
         cfg["output_dir"] = paths["out_dir"]
     else:
         sol["output_dir"] = paths["out_dir"]
 
-    # pretrained weights in the right model block
     if pretrained_path:
         p = os.path.abspath(pretrained_path)
         _maybe_set_model_field(cfg, "pretrain", p)
         _maybe_set_model_field(cfg, "pretrained", p)
 
-    # Save near the template so internal relative references still make sense
     cfg_out_dir = os.path.join(template_dir, "generated")
     os.makedirs(cfg_out_dir, exist_ok=True)
     out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
 
-    # Force block style for lists (no inline [a, b, c])
     class _NoFlowDumper(yaml.SafeDumper): ...
     def _repr_list_block(dumper, data):
         return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False)
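Note: the `_NoFlowDumper` trick pins sequence style at the dumper level, so every list in the generated YAML comes out in block form rather than inline `[a, b, c]`, independent of PyYAML's dump defaults. A runnable sketch (the key name is illustrative):

import yaml

class NoFlowDumper(yaml.SafeDumper): ...

NoFlowDumper.add_representer(
    list,
    lambda dumper, data: dumper.represent_sequence(
        "tag:yaml.org,2002:seq", data, flow_style=False),
)

print(yaml.dump({"eval_spatial_size": [640, 640]}, Dumper=NoFlowDumper))
# eval_spatial_size:
# - 640
# - 640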
@@ -832,7 +804,6 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     out_dir = os.path.abspath(os.path.join("runs", "train", run_name))
     os.makedirs(out_dir, exist_ok=True)
 
-    # Download matching COCO checkpoint for warm-start
     pretrained_path = _ensure_checkpoint(model_key, out_dir)
 
     cfg_path = patch_base_config(
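Note: `_ensure_checkpoint`'s body sits outside this diff; judging by the comment it removes ("Download matching COCO checkpoint for warm-start"), a plausible shape (an assumption, with a placeholder URL table) is download-if-missing:

import os, requests

CKPT_URLS = {"rtdetrv2_s": "https://example.com/rtdetrv2_s_coco.pth"}  # hypothetical

def ensure_checkpoint(model_key, out_dir):
    # Return a local checkpoint path for the model key, downloading once.
    url = CKPT_URLS.get(model_key)
    if url is None:
        return None
    dst = os.path.join(out_dir, os.path.basename(url))
    if not os.path.exists(dst):
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(dst, "wb") as f:
                for chunk in r.iter_content(1 << 20):
                    f.write(chunk)
    return dst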
@@ -855,13 +826,17 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     def run_train():
         try:
             env = os.environ.copy()
+            # Make sure repo code can be imported
             env["PYTHONPATH"] = os.pathsep.join(filter(None, [
                 PY_IMPL_DIR, REPO_DIR, env.get("PYTHONPATH", "")
             ]))
-            #
+            # Put our shim first so supervisely import never breaks
             shim_root = _install_supervisely_logger_shim()
             env["PYTHONPATH"] = os.pathsep.join([shim_root, env["PYTHONPATH"]])
             env.setdefault("WANDB_DISABLED", "true")
+            # Provide a secondary hint for some config loaders
+            env.setdefault("RTDETR_PYMODULE", "rtdetrv2_pytorch.src")
+
             proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
                                     stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                     bufsize=1, text=True, env=env)
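Note: the search order the two PYTHONPATH assignments produce is worth spelling out: shim first, then the repo's PyTorch implementation, then the repo root, then anything inherited. A quick check with stand-in paths:

import os

env = {"PYTHONPATH": "/opt/inherited"}
merged = os.pathsep.join(filter(None, [
    "/app/third_party/RT-DETRv2/rtdetrv2_pytorch",   # PY_IMPL_DIR (stand-in)
    "/app/third_party/RT-DETRv2",                    # REPO_DIR (stand-in)
    env.get("PYTHONPATH", ""),
]))
merged = os.pathsep.join(["/tmp/sly_shim_pkg", merged])  # shim goes first
print(merged.split(os.pathsep))
# ['/tmp/sly_shim_pkg', '/app/third_party/RT-DETRv2/rtdetrv2_pytorch',
#  '/app/third_party/RT-DETRv2', '/opt/inherited']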
@@ -882,13 +857,13 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
             if line.startswith("__EXITCODE__"):
                 code = int(line.split(":", 1)[1])
                 if code != 0:
-                    head = "\n".join(first_lines[:
+                    head = "\n".join(first_lines[-200:])
                     raise gr.Error(f"Training exited with code {code}.\nLast output:\n{head or 'No logs captured.'}")
                 break
             if line.startswith("__ERROR__"):
                 raise gr.Error(f"Training failed: {line.split(':', 1)[1]}")
 
-            if len(first_lines) <
+            if len(first_lines) < 2000:
                 first_lines.append(line)
             log_tail.append(line)
             log_tail = log_tail[-40:]
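Note: the `__EXITCODE__`/`__ERROR__` lines form a sentinel protocol between the training wrapper and this reader loop. A minimal end-to-end sketch (the real child is the RT-DETR train script wrapper, not this inline stub):

import subprocess, sys

child_code = (
    "print('epoch 1/10 done', flush=True)\n"
    "print('__EXITCODE__:0', flush=True)\n"
)
proc = subprocess.Popen([sys.executable, "-c", child_code],
                        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                        bufsize=1, text=True)
for raw in proc.stdout:
    line = raw.rstrip()
    if line.startswith("__EXITCODE__"):
        code = int(line.split(":", 1)[1])
        print("training finished with exit code", code)
        break
    print("log:", line)
proc.wait()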
@@ -902,7 +877,6 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
                 pass
             progress(min(max(last_epoch / max(1, total_epochs), 0.0), 1.0), desc=f"Epoch {last_epoch}/{total_epochs}")
 
-            # Throttle plotting; close figs after yield to avoid leaks
             line_no += 1
             fig1 = fig2 = None
             if line_no % 80 == 0: