wuhp committed
Commit 985f19d · verified · 1 Parent(s): b0cabfb

Update app.py

Files changed (1)
  1. app.py +145 -71
app.py CHANGED
@@ -1,4 +1,4 @@
-# app.py — Rolo: RT-DETRv2-only (Supervisely) trainer with auto COCO conversion & config
+# app.py — Rolo: RT-DETRv2-only (Supervisely) trainer with auto COCO conversion & safe config patching
 import os, sys, subprocess, shutil, stat, yaml, gradio as gr, re, random, logging, requests, json, base64, time
 from urllib.parse import urlparse
 from glob import glob
@@ -336,71 +336,141 @@ def find_training_script(repo_root):
 
 def find_model_config_template(model_key):
     """
-    Find a base config YAML in the repo that matches the chosen model key.
-    We look under any configs directory for a yaml containing 'rtdetrv2' and the model key.
+    Choose a native RT-DETRv2 config YAML from the Supervisely repo.
+
+    Heuristics:
+      - rtdetrv2_s -> r18 (Small)
+      - rtdetrv2_l -> r50 (Large)
+      - rtdetrv2_x -> r101 (X-Large)
+    Prefer files under rtdetrv2_pytorch/**/config(s) and with 'coco' in the name.
     """
+    want_tokens = {
+        "rtdetrv2_s": ["rtdetrv2", "r18", "coco"],
+        "rtdetrv2_l": ["rtdetrv2", "r50", "coco"],
+        "rtdetrv2_x": ["rtdetrv2", "r101", "coco"],
+    }.get(model_key, ["rtdetrv2", "r18", "coco"])
+
     yamls = glob(os.path.join(REPO_DIR, "**", "*.yml"), recursive=True) + \
             glob(os.path.join(REPO_DIR, "**", "*.yaml"), recursive=True)
-    # prioritize files with both rtdetrv2 and the exact key in the name
+
     def score(p):
-        n = os.path.basename(p).lower()
+        pl = p.lower()
         s = 0
-        if "rtdetrv2" in n: s += 2
-        if model_key in n: s += 3
-        if "coco" in n: s += 1
+        if "/rtdetrv2_pytorch/" in pl: s += 4
+        if "/config" in pl: s += 3
+        for token in want_tokens:
+            if token in os.path.basename(pl): s += 3
+            if token in pl: s += 2
+        if "coco" in pl: s += 1
        return -s, len(p)
+
    yamls.sort(key=score)
    return yamls[0] if yamls else None
 
-def write_custom_config(base_cfg_path, merged_dir, class_count, model_key, run_name, epochs, batch, imgsz, lr, optimizer):
+def patch_base_config(base_cfg_path, merged_dir, class_count, run_name,
+                      epochs, batch, imgsz, lr, optimizer):
     """
-    Generate a small override config that points to our COCO jsons and sets key hyperparams.
-    This YAML gets merged by the repo's config system if it supports '_base_' includes;
-    otherwise, it still provides reasonable keys many RT-DETRv2 forks accept.
+    Load the chosen repo config and patch only the keys that already exist.
+    This avoids schema mismatches between forks.
     """
-    ann_dir = os.path.join(merged_dir, "annotations")
-    cfg_out_dir = os.path.join("generated_configs")
-    os.makedirs(cfg_out_dir, exist_ok=True)
-    out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
-
-    # Try a broadly compatible structure (kept simple on purpose)
-    override = {
-        "_base_": os.path.relpath(base_cfg_path, start=cfg_out_dir) if base_cfg_path else None,
-        "model": {"name": model_key, "num_classes": int(class_count)},
-        "input_size": int(imgsz),
-        "max_epoch": int(epochs),
-        "solver": {
-            "base_lr": float(lr),
-            "optimizer": str(optimizer).lower(),  # "adam", "adamw", "sgd"
-            "batch_size": int(batch),
-        },
-        "dataset": {
-            "train": {
-                "name": "coco",
-                "ann_file": os.path.abspath(os.path.join(ann_dir, "instances_train.json")),
-                "img_prefix": os.path.abspath(os.path.join(merged_dir, "train", "images")),
-            },
-            "val": {
-                "name": "coco",
-                "ann_file": os.path.abspath(os.path.join(ann_dir, "instances_val.json")),
-                "img_prefix": os.path.abspath(os.path.join(merged_dir, "valid", "images")),
-            },
-            "test": {
-                "name": "coco",
-                "ann_file": os.path.abspath(os.path.join(ann_dir, "instances_test.json")),
-                "img_prefix": os.path.abspath(os.path.join(merged_dir, "test", "images")),
-            },
-        },
-        "output_dir": os.path.abspath(os.path.join("runs", "train", run_name)),
-        # some forks use these dataloader keys:
-        "train_dataloader": {"batch_size": int(batch)},
-        "val_dataloader": {"batch_size": int(batch)},
-    }
-    # drop None values cleanly
-    if override["_base_"] is None:
-        del override["_base_"]
-
-    with open(out_path, "w") as f: yaml.safe_dump(override, f, sort_keys=False)
+    if not base_cfg_path or not os.path.exists(base_cfg_path):
+        raise gr.Error("Could not locate a model config inside the RT-DETRv2 repo.")
+
+    with open(base_cfg_path, "r") as f:
+        cfg = yaml.safe_load(f)
+
+    ann_dir = os.path.join(merged_dir, "annotations")
+    paths = {
+        "train_json": os.path.abspath(os.path.join(ann_dir, "instances_train.json")),
+        "val_json":   os.path.abspath(os.path.join(ann_dir, "instances_val.json")),
+        "test_json":  os.path.abspath(os.path.join(ann_dir, "instances_test.json")),
+        "train_img":  os.path.abspath(os.path.join(merged_dir, "train", "images")),
+        "val_img":    os.path.abspath(os.path.join(merged_dir, "valid", "images")),  # Roboflow uses 'valid'
+        "test_img":   os.path.abspath(os.path.join(merged_dir, "test", "images")),
+        "out_dir":    os.path.abspath(os.path.join("runs", "train", run_name)),
+    }
+
+    # --- dataset block --------------------------------------------------------
+    for root_key in ["dataset", "data"]:
+        if root_key in cfg and isinstance(cfg[root_key], dict):
+            ds = cfg[root_key]
+            for split, jf, ip in [
+                ("train", "train_json", "train_img"),
+                ("val", "val_json", "val_img"),
+                ("test", "test_json", "test_img"),
+            ]:
+                if split in ds and isinstance(ds[split], dict):
+                    ds[split]["name"] = ds[split].get("name", "coco")
+                    # Common key variants across forks: reuse an existing key, else default
+                    for k in ["ann_file", "ann_path", "annotation", "annotations"]:
+                        if k in ds[split]:
+                            ds[split][k] = paths[jf]; break
+                    else:
+                        ds[split]["ann_file"] = paths[jf]
+                    for k in ["img_prefix", "img_dir", "image_root", "data_root"]:
+                        if k in ds[split]:
+                            ds[split][k] = paths[ip]; break
+                    else:
+                        ds[split]["img_prefix"] = paths[ip]
+
+    # --- num_classes ----------------------------------------------------------
+    def set_num_classes(node, n):
+        if not isinstance(node, dict): return False
+        if "num_classes" in node:
+            node["num_classes"] = int(n); return True
+        for k, v in node.items():
+            if isinstance(v, dict) and set_num_classes(v, n): return True
+        return False
+
+    if "model" in cfg and isinstance(cfg["model"], dict):
+        if not set_num_classes(cfg["model"], class_count):
+            cfg["model"]["num_classes"] = int(class_count)
+    else:
+        cfg["model"] = {"num_classes": int(class_count)}
+
+    # --- epochs / image size --------------------------------------------------
+    updated_epoch = False
+    for key in ["max_epoch", "epochs", "num_epochs"]:
+        if key in cfg:
+            cfg[key] = int(epochs); updated_epoch = True; break
+    if "solver" in cfg and isinstance(cfg["solver"], dict):
+        for key in ["max_epoch", "epochs", "num_epochs"]:
+            if key in cfg["solver"]:
+                cfg["solver"][key] = int(epochs); updated_epoch = True; break
+    if not updated_epoch:
+        cfg["max_epoch"] = int(epochs)
+
+    for key in ["input_size", "img_size", "imgsz"]:
+        if key in cfg: cfg[key] = int(imgsz)
+    if "input_size" not in cfg: cfg["input_size"] = int(imgsz)
+
+    # --- learning rate / optimizer / batch ------------------------------------
+    if "solver" not in cfg or not isinstance(cfg["solver"], dict):
+        cfg["solver"] = {}
+    sol = cfg["solver"]
+    for key in ["base_lr", "lr", "learning_rate"]:
+        if key in sol:
+            sol[key] = float(lr); break
+    else:
+        sol["base_lr"] = float(lr)
+
+    sol["optimizer"] = str(optimizer).lower()
+
+    if "train_dataloader" in cfg and isinstance(cfg["train_dataloader"], dict):
+        cfg["train_dataloader"]["batch_size"] = int(batch)
+    else:
+        sol["batch_size"] = int(batch)
+
+    # --- output dir ------------------------------------------------------------
+    if "output_dir" in cfg:
+        cfg["output_dir"] = paths["out_dir"]
+    elif "solver" in cfg:
+        sol["output_dir"] = paths["out_dir"]
+    else:
+        cfg["output_dir"] = paths["out_dir"]
+
+    # --- write patched config --------------------------------------------------
+    cfg_out_dir = "generated_configs"
+    os.makedirs(cfg_out_dir, exist_ok=True)
+    out_path = os.path.join(cfg_out_dir, f"{run_name}.yaml")
+    with open(out_path, "w") as f: yaml.safe_dump(cfg, f, sort_keys=False)
     return out_path
 
 def find_best_checkpoint(out_dir):
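The new `patch_base_config` only overwrites keys the base YAML already defines, so the patched file keeps whatever schema the chosen fork expects. A minimal standalone sketch of that idea, using an invented toy config rather than a real RT-DETRv2 YAML:

```python
# Sketch of "patch only keys that already exist" (toy config; values invented).
import yaml

cfg = yaml.safe_load("""
model: {name: rtdetrv2_r18, num_classes: 80}
solver: {base_lr: 0.0001, optimizer: adamw}
""")

def patch_existing(node, key, value, default_key=None):
    """Overwrite `key` if the schema defines it; else optionally add `default_key`."""
    if key in node:
        node[key] = value
    elif default_key:
        node[default_key] = value

patch_existing(cfg["model"], "num_classes", 3)                      # key exists -> patched
patch_existing(cfg["solver"], "lr", 0.0005, default_key="base_lr")  # key absent -> fallback
print(yaml.safe_dump(cfg, sort_keys=False))
```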
@@ -500,26 +570,27 @@ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
 def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
     if not dataset_path: raise gr.Error("Finalize a dataset in Tab 2 before training.")
 
-    # 1) find training script (nested-safe)
+    # 1) training script (nested-safe)
     train_script = find_training_script(REPO_DIR)
     if not train_script:
         raise gr.Error("RT-DETRv2 training script not found inside the repo (looked for **/tools/train.py).")
 
-    # 2) pick a model config template from repo (best effort)
+    # 2) base config = a real model template from the repo
     base_cfg = find_model_config_template(model_key)
+    if not base_cfg:
+        raise gr.Error("Could not find a matching RT-DETRv2 config in the repo (S/L/X).")
 
-    # 3) read class names from our merged data.yaml to set num_classes + produce COCO JSONs
+    # 3) read classes + ensure COCO JSONs up to date
     data_yaml = os.path.join(dataset_path, "data.yaml")
     with open(data_yaml, "r") as f: dy = yaml.safe_load(f)
     class_names = [str(x) for x in dy.get("names", [])]
-    ann_dir = make_coco_annotations(dataset_path, class_names)
+    make_coco_annotations(dataset_path, class_names)
 
-    # 4) write a small override config that points to our data and injects hyper-params
-    cfg_path = write_custom_config(
+    # 4) patch the base config safely (no custom schema assumptions)
+    cfg_path = patch_base_config(
         base_cfg_path=base_cfg,
         merged_dir=dataset_path,
         class_count=len(class_names),
-        model_key=model_key,
         run_name=run_name,
         epochs=epochs,
         batch=batch,
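`patch_base_config` points the dataset block at `instances_train.json` / `instances_val.json` / `instances_test.json`, i.e. standard COCO detection files. For orientation, a minimal skeleton of that format (the field names are standard COCO; the image, box, and class below are invented, not the converter's actual output):

```python
# Minimal COCO "instances" file skeleton (standard fields; example values invented).
import json

coco = {
    "images": [{"id": 1, "file_name": "0001.jpg", "width": 640, "height": 640}],
    "annotations": [{
        "id": 1, "image_id": 1, "category_id": 1,
        "bbox": [100.0, 120.0, 50.0, 80.0],  # [x, y, width, height] in pixels
        "area": 4000.0, "iscrowd": 0,
    }],
    "categories": [{"id": 1, "name": "helmet"}],  # one entry per name in data.yaml
}

with open("instances_train.json", "w") as f:
    json.dump(coco, f)
```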
@@ -531,18 +602,20 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     out_dir = os.path.abspath(os.path.join("runs", "train", run_name))
     os.makedirs(out_dir, exist_ok=True)
 
-    # 5) build & run the command (single-GPU by default, no manual CLI edits)
+    # 5) build & run command (no extra flags that might not exist)
     cmd = [sys.executable, train_script, "-c", os.path.abspath(cfg_path)]
-    # many forks accept optional flags; pass safe ones if present
-    if "--use-amp" in open(train_script).read():  # cheap check
-        cmd += ["--use-amp"]
     logging.info(f"Training command: {' '.join(cmd)}")
 
     q = Queue()
     def run_train():
         try:
             env = os.environ.copy()
-            env["PYTHONPATH"] = REPO_DIR + os.pathsep + env.get("PYTHONPATH", "")
+            # Ensure both repo root and pytorch impl are on PYTHONPATH
+            env["PYTHONPATH"] = os.pathsep.join(filter(None, [
+                PY_IMPL_DIR, REPO_DIR, env.get("PYTHONPATH", "")
+            ]))
+            # Disable wandb in Spaces by default
+            env.setdefault("WANDB_DISABLED", "true")
             proc = subprocess.Popen(cmd, cwd=os.path.dirname(train_script),
                                     stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                     bufsize=1, text=True, env=env)
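The handler launches the trainer with `Popen` and forwards each stdout line through a `Queue` so the UI thread can poll progress without blocking. A stripped-down sketch of that pattern, with a trivial child process standing in for the real training script:

```python
# Stripped-down version of the stream-stdout-through-a-queue pattern used above.
import subprocess, sys
from queue import Queue
from threading import Thread

cmd = [sys.executable, "-c", "print('Epoch 1/3'); print('Epoch 2/3')"]  # stand-in child
q = Queue()

def pump():
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, text=True, bufsize=1)
    for line in proc.stdout:               # line-buffered: forward as it arrives
        q.put(line.rstrip("\n"))
    q.put(f"__EXITCODE__:{proc.wait()}")   # sentinel, as in the handler above

Thread(target=pump, daemon=True).start()
while True:
    line = q.get()
    if line.startswith("__EXITCODE__"):
        break
    print("log:", line)
```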
@@ -555,25 +628,26 @@ def training_handler(dataset_path, model_key, run_name, epochs, batch, imgsz, lr
     Thread(target=run_train, daemon=True).start()
 
     log_tail, last_epoch, total_epochs = [], 0, int(epochs)
+    first_lines = []  # capture early output for a nicer error message
     while True:
         line = q.get()
         if line.startswith("__EXITCODE__"):
             code = int(line.split(":",1)[1])
-            if code != 0: raise gr.Error(f"Training exited with code {code}. See logs above.")
+            if code != 0:
+                head = "\n".join(first_lines[:60])
+                raise gr.Error(f"Training exited with code {code}.\nFirst output:\n{head or 'No logs captured.'}")
             break
         if line.startswith("__ERROR__"):
             raise gr.Error(f"Training failed: {line.split(':',1)[1]}")
 
-        log_tail.append(line)
-        log_tail = log_tail[-30:]
+        if len(first_lines) < 120: first_lines.append(line)
+        log_tail.append(line); log_tail = log_tail[-40:]
 
         m = re.search(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)", line)
         if m:
             try:
-                last_epoch = int(m.group(1))
-                total_epochs = max(total_epochs, int(m.group(2)))
-            except Exception:
-                pass
+                last_epoch = int(m.group(1)); total_epochs = max(total_epochs, int(m.group(2)))
+            except Exception: pass
         progress(min(max(last_epoch / max(1,total_epochs),0.0),1.0), desc=f"Epoch {last_epoch}/{total_epochs}")
 
     fig1 = plt.figure(); plt.title("Loss (see logs)")
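Progress reporting hinges on the `[Ee]poch\s+(\d+)\s*/\s*(\d+)` pattern matching the trainer's log lines; if a fork logs epochs in another format, the bar simply stays put. A quick check against invented log lines:

```python
# Quick check of the epoch-progress regex against invented log lines.
import re

pattern = re.compile(r"[Ee]poch\s+(\d+)\s*/\s*(\d+)")
for line in ["Epoch 3/50: loss=0.41", "epoch 12 / 100", "eta 0:01:02"]:
    m = pattern.search(line)
    print(line, "->", m.groups() if m else "no match")
# Epoch 3/50: loss=0.41 -> ('3', '50')
# epoch 12 / 100 -> ('12', '100')
# eta 0:01:02 -> no match
```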
 