wuhp committed
Commit f735495 · verified · 1 parent: 0257e16

Update app.py

Files changed (1):
  1. app.py +245 -201
app.py CHANGED
@@ -3,7 +3,6 @@ import shutil
  import stat
  import yaml
  import gradio as gr
- from ultralytics import YOLO  # Ultralytics RT-DETR runner
  from roboflow import Roboflow
  import re
  from urllib.parse import urlparse
@@ -12,34 +11,31 @@ import logging
  import requests
  import json
  from PIL import Image
- import torch
  import pandas as pd
  import matplotlib.pyplot as plt
  from threading import Thread
  from queue import Queue
  from huggingface_hub import HfApi, HfFolder
  import base64

  # --- Configuration ---
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

- # Hardcode RT-DETR model configurations. All YOLO options are removed.
- RTDETR_MODELS = {
-     "detection": [
-         {
-             "filename": "rtdetr-l.pt",
-             "url": "https://github.com/ultralytics/assets/releases/download/v8.0.0/rtdetr-l.pt",
-             "description": "RT-DETR Large model (Default)"
-         },
-         {
-             "filename": "rtdetr-x.pt",
-             "url": "https://github.com/ultralytics/assets/releases/download/v8.0.0/rtdetr-x.pt",
-             "description": "RT-DETR Extra-Large model."
-         }
-     ]
- }
- DEFAULT_MODEL = "rtdetr-l.pt"

  # ------------------------------
  # Utilities
@@ -53,19 +49,18 @@ def handle_remove_readonly(func, path, exc_info):
          pass
      func(path)

-
  _ROBO_URL_RX = re.compile(
      r"""
      ^(?:
-         (?:https?://)?(?:universe|app|www)?\.?roboflow\.com/                          # Any roboflow host
-         (?P<ws>[A-Za-z0-9\-_]+)/                                                      # workspace
-         (?P<proj>[A-Za-z0-9\-_]+)/?                                                   # project
          (?:
-             (?:dataset/[^/]+/)?                                                       # optional 'dataset/<fmt>/'
-             (?:v?(?P<ver>\d+))?                                                       # optional version 'vN' or 'N'
          )?
      |
-         (?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))?  # raw ws/proj[/vN]
      )$
      """,
      re.VERBOSE | re.IGNORECASE
@@ -73,15 +68,14 @@ _ROBO_URL_RX = re.compile(

  def parse_roboflow_url(s: str):
      """
-     Accepts:
-     - https://universe.roboflow.com/<workspace>/<project>[/vN | /N]
-     - https://app.roboflow.com/<workspace>/<project>[/vN | /N]
-     - https://roboflow.com/<workspace>/<project>[/vN | /N]
-     - raw: <workspace>/<project>[/vN | /N]
      Returns: (workspace, project, version_or_None)
      """
      s = s.strip()
-     # Fast path: try regex
      m = _ROBO_URL_RX.match(s)
      if m:
          ws = m.group('ws') or m.group('ws2')
@@ -89,14 +83,11 @@ def parse_roboflow_url(s: str):
          ver = m.group('ver') or m.group('ver2')
          return ws, proj, (int(ver) if ver else None)

-     # Fallback: parse like URL and split path
      parsed = urlparse(s)
      parts = [p for p in parsed.path.strip('/').split('/') if p]
      if len(parts) >= 2:
-         # Try to pull raw version from the 3rd part if it exists
          version = None
          if len(parts) >= 3:
-             # Accept 'vN' or 'N'
              vpart = parts[2]
              if vpart.lower().startswith('v') and vpart[1:].isdigit():
                  version = int(vpart[1:])
@@ -104,11 +95,9 @@ def parse_roboflow_url(s: str):
              version = int(vpart)
          return parts[0], parts[1], version

-     # Fallback raw "ws/proj" without slashes in URL
      if '/' in s and 'roboflow' not in s:
          p = s.split('/')
          if len(p) >= 2:
-             # Accept trailing version if present
              version = None
              if len(p) >= 3:
                  v = p[2]
@@ -120,7 +109,6 @@ def parse_roboflow_url(s: str):

      return None, None, None
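For reference, a quick sanity check of the parser above (the workspace/project names here are made up; the expected tuples follow from the regex and fallbacks as defined):

      parse_roboflow_url("https://universe.roboflow.com/acme/widgets/3")  # -> ("acme", "widgets", 3)
      parse_roboflow_url("https://app.roboflow.com/acme/widgets/v2")      # -> ("acme", "widgets", 2)
      parse_roboflow_url("acme/widgets")                                  # -> ("acme", "widgets", None)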

-
  def get_latest_version(api_key, workspace, project):
      """Gets the latest version number of a Roboflow project."""
      try:
@@ -132,15 +120,14 @@ def get_latest_version(api_key, workspace, project):
          logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
          return None

-
- # --- NEW: normalize class names from data.yaml ---
  def _extract_class_names(data_yaml):
      """
-     Return a list[str] of class names in index order.
-     Handles:
-     - list (possibly containing non-str types)
-     - dict with numeric keys (e.g., {0: 'cat', 1: 'dog'})
-     - fallback to ['class_0', ..., f'class_{nc-1}'] if names missing
      """
      names = data_yaml.get('names', None)

@@ -150,8 +137,8 @@ def _extract_class_names(data_yaml):
                  return int(x)
              except Exception:
                  return str(x)
-         ordered_keys = sorted(names.keys(), key=_k)
-         names_list = [names[k] for k in ordered_keys]
      elif isinstance(names, list):
          names_list = names
      else:
@@ -164,9 +151,8 @@ def _extract_class_names(data_yaml):

      return [str(x) for x in names_list]
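As a quick illustration of the normalization (hypothetical data.yaml contents; the fallback case follows the docstring above):

      _extract_class_names({'names': ['cat', 'dog']})        # -> ['cat', 'dog']
      _extract_class_names({'names': {1: 'dog', 0: 'cat'}})  # -> ['cat', 'dog']
      _extract_class_names({'nc': 2})                        # -> ['class_0', 'class_1']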

-
  def download_dataset(api_key, workspace, project, version):
-     """Downloads a single dataset from Roboflow (yolov8 format works fine for RT-DETR)."""
      try:
          rf = Roboflow(api_key=api_key)
          proj = rf.workspace(workspace).project(project)
@@ -177,7 +163,6 @@ def download_dataset(api_key, workspace, project, version):
          with open(data_yaml_path, 'r') as f:
              data_yaml = yaml.safe_load(f)

-         # --- UPDATED: use normalized names and optional sanity log ---
          class_names = _extract_class_names(data_yaml)
          try:
              nc = int(data_yaml.get('nc', len(class_names)))
@@ -194,30 +179,25 @@ def download_dataset(api_key, workspace, project, version):
          logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
          return None, [], [], None

-
  def label_path_for(img_path: str) -> str:
-     """Convert .../split/images/file.jpg -> .../split/labels/file.txt in a safe way."""
      split_dir = os.path.dirname(os.path.dirname(img_path))  # .../split
      base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
      return os.path.join(split_dir, 'labels', base)
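For example (illustrative path):

      label_path_for("data/train/images/img001.jpg")  # -> "data/train/labels/img001.txt"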

-
  def gather_class_counts(dataset_info, class_mapping):
      """
-     Count, per final class, how many images contain at least one instance of that class
-     (counted once per image). class_mapping maps original_name -> final_name.
      """
      if not dataset_info:
          return {}

-     final_names = set(class_mapping.values())
      counts = {name: 0 for name in final_names}

      for loc, names, splits, _ in dataset_info:
-         # Map from original idx -> mapped name (or None if removed later)
-         id_to_name = {}
-         for idx, n in enumerate(names):
-             id_to_name[idx] = class_mapping.get(n, None)

          for split in splits:
              labels_dir = os.path.join(loc, split, 'labels')
@@ -244,9 +224,8 @@ def gather_class_counts(dataset_info, class_mapping):

      return counts
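Note the counting semantics: counts are per image, not per box. With class_mapping = {'cat': 'animal', 'dog': 'animal'}, a label file containing two cats and one dog contributes exactly 1 to counts['animal'].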

-
  def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
-     """Core function to merge datasets based on user rules."""
      merged_dir = 'rolo_merged_dataset'
      if os.path.exists(merged_dir):
          shutil.rmtree(merged_dir, onerror=handle_remove_readonly)
@@ -256,12 +235,10 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
          os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
          os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)

-     # Only classes with positive limits are active
-     active_classes = [cls for cls, limit in class_limits.items() if limit > 0]
-     active_classes = sorted(set(active_classes))
      final_class_map = {name: i for i, name in enumerate(active_classes)}

-     # Collect all candidate images
      all_images = []
      for loc, _, splits, _ in dataset_info:
          for split in splits:
@@ -276,8 +253,6 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
      progress(0.2, desc="Selecting images based on limits...")
      selected_images = []
      current_counts = {cls: 0 for cls in active_classes}
-
-     # Build a quick lookup: source_loc -> names list
      loc_to_names = {info[0]: info[1] for info in dataset_info}

      for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):
@@ -303,8 +278,6 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=

          if not image_classes:
              continue
-
-         # Check limits
          if any(current_counts[c] >= class_limits[c] for c in image_classes):
              continue
@@ -319,7 +292,6 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=
          out_lbl = os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path))
          shutil.copy(img_path, out_img)

-         # Determine source names by matching the parent dataset root
          source_loc = None
          for info in dataset_info:
              if img_path.startswith(info[0]):
@@ -355,6 +327,68 @@ def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=

      return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)
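The relabeling step rewrites each YOLO-format label line ("<class_id> <cx> <cy> <w> <h>") with the merged class index. A minimal sketch of that idea (the helper name and arguments are illustrative, not the exact variables used in app.py):

      def remap_line(line, id_to_name, final_class_map):
          parts = line.split()
          name = id_to_name.get(int(parts[0]))        # original id -> mapped name
          if name is None or name not in final_class_map:
              return None                             # class removed or inactive
          return " ".join([str(final_class_map[name])] + parts[1:])

      remap_line("3 0.5 0.5 0.2 0.1", {3: "animal"}, {"animal": 0})  # -> "0 0.5 0.5 0.2 0.1"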

  # ------------------------------
  # Gradio UI Event Handlers
@@ -394,36 +428,29 @@ def load_datasets_handler(api_key, url_file, progress=gr.Progress()):
              failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))

      if not dataset_info:
-         # Show a compact failure report to the UI
          msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
          raise gr.Error(msg)

-     # --- UPDATED: ensure all names are strings before sorting
      all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
      class_map = {name: name for name in all_names}

-     # Initial preview uses "keep all" mapping
      initial_counts = gather_class_counts(dataset_info, class_map)
      df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
      status_text = "Datasets loaded successfully."
      if failures:
          status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."

-     return status_text, dataset_info, gr.DataFrame.update(
          value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
      )

-
  def update_class_counts_handler(class_df, dataset_info):
-     """
-     Provides live feedback on class counts as the user edits the DataFrame.
-     We compute a mapping of original -> final (or None if removed), then count images
-     for each final name.
-     """
      if class_df is None or not dataset_info:
          return None

-     # Build mapping original_name -> final_name or None if removed
      class_df = pd.DataFrame(class_df)
      mapping = {}
      for _, row in class_df.iterrows():
@@ -433,14 +460,11 @@ def update_class_counts_handler(class_df, dataset_info):
          else:
              mapping[orig] = row["Rename To"]

-     # Build final set
      final_names = sorted(set(v for v in mapping.values() if v))
      counts = {k: 0 for k in final_names}

      for loc, names, splits, _ in dataset_info:
-         id_to_final = {}
-         for idx, n in enumerate(names):
-             id_to_final[idx] = mapping.get(n, None)

          for split in splits:
              labels_dir = os.path.join(loc, split, 'labels')
@@ -468,15 +492,13 @@ def update_class_counts_handler(class_df, dataset_info):
      summary_df = pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
      return summary_df

-
  def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
-     """Handles the 'Finalize' button click."""
      if not dataset_info:
          raise gr.Error("Load datasets first in Tab 1.")
      if class_df is None:
          raise gr.Error("Class data is missing.")

-     # Mapping and limits
      class_df = pd.DataFrame(class_df)
      class_mapping = {}
      class_limits = {}
@@ -486,112 +508,109 @@ def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
              continue
          final_name = row["Rename To"]
          class_mapping[orig] = final_name
-         # Sum limits for final_name over any merged originals
          class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])

      status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
      return status, path

-
- def training_handler(dataset_path, model_filename, run_name, epochs, batch, imgsz, lr, opt, progress=gr.Progress()):
-     """Handles the training process with live feedback."""
      if not dataset_path:
          raise gr.Error("Finalize a dataset in Tab 2 before training.")

-     # Ultralytics expects device string, e.g. '0' or 'cpu'
-     device_str = "0" if torch.cuda.is_available() else "cpu"
-
-     metrics_queue = Queue()
-
-     def on_epoch_end(trainer):
-         # Be defensive about metric keys
-         m = trainer.metrics or {}
-         metrics_queue.put({
-             'epoch': (trainer.epoch or 0) + 1,
-             'train_loss': m.get('train/loss') or m.get('loss'),
-             'val_loss': m.get('val/loss'),
-             'mAP50': m.get('metrics/mAP50(B)') or m.get('metrics/mAP50'),
-             'mAP50_95': m.get('metrics/mAP50-95(B)') or m.get('metrics/mAP50-95')
-         })
-
-     def train_thread_func():
-         try:
-             model_url = next(m['url'] for m in RTDETR_MODELS['detection'] if m['filename'] == model_filename)
-             weights_path = os.path.join('pretrained_models', model_filename)
-             if not os.path.exists(weights_path):
-                 os.makedirs('pretrained_models', exist_ok=True)
-                 r = requests.get(model_url, stream=True, timeout=60)
-                 r.raise_for_status()
-                 with open(weights_path, 'wb') as f:
-                     for chunk in r.iter_content(chunk_size=8192):
-                         f.write(chunk)
-
-             model = YOLO(weights_path)
-             model.add_callback("on_train_epoch_end", on_epoch_end)
-
-             model.train(
-                 data=os.path.join(dataset_path, 'data.yaml'),
-                 epochs=int(epochs),
-                 batch=int(batch),
-                 imgsz=int(imgsz),
-                 lr0=float(lr),
-                 optimizer=str(opt),
-                 project='runs/train',
-                 name=str(run_name),
-                 exist_ok=True,
-                 device=device_str
-             )
-             metrics_queue.put("done")
-         except Exception as e:
-             logging.exception("Training thread error")
-             metrics_queue.put(f"error: {e}")

-     Thread(target=train_thread_func, daemon=True).start()

      history = {k: [] for k in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']}
-     while True:
-         item = metrics_queue.get()
-         if isinstance(item, str):
-             if item == "done":
-                 break
-             if item.startswith("error"):
-                 raise gr.Error(f"Training failed: {item}")
-
-         # Append metrics
-         for key in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']:
-             val = item.get(key, None)
-             if val is not None:
-                 history[key].append(val)
-
-         current_epoch = history['epoch'][-1] if history['epoch'] else 0
-         total_epochs = int(epochs)
-         frac = min(max(current_epoch / max(1, total_epochs), 0.0), 1.0)
-         progress(frac, desc=f"Epoch {current_epoch}/{total_epochs}")
-
-         # Plot Loss
-         fig_loss = plt.figure()
-         ax_loss = fig_loss.add_subplot(111)
-         ax_loss.plot(history['epoch'], history['train_loss'], "o-", label='Train Loss')
-         ax_loss.plot(history['epoch'], history['val_loss'], "o-", label='Val Loss')
-         ax_loss.legend()
-         ax_loss.set_title("Loss")
-
-         # Plot mAP
-         fig_map = plt.figure()
-         ax_map = fig_map.add_subplot(111)
-         ax_map.plot(history['epoch'], history['mAP50'], "o-", label='mAP@0.5')
-         ax_map.plot(history['epoch'], history['mAP50_95'], "o-", label='mAP@0.5:0.95')
-         ax_map.legend()
-         ax_map.set_title("mAP")
-
-         yield f"Epoch {current_epoch}/{total_epochs} complete.", fig_loss, fig_map, None
-
-     final_path = os.path.join('runs', 'train', str(run_name), 'weights', 'best.pt')
-     if not os.path.exists(final_path):
-         raise gr.Error("Training finished, but 'best.pt' was not found.")
-
-     yield "Training complete!", None, None, gr.File.update(value=final_path, visible=True)
@@ -649,12 +668,11 @@ def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr
      progress(1)
      return hf_status, gh_status

-
  # ------------------------------
  # Gradio UI
  # ------------------------------
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
-     gr.Markdown("# Rolo: A Dedicated RT-DETR Training Dashboard")

      # State variables
      dataset_info_state = gr.State([])
@@ -689,38 +707,54 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
          finalize_status = gr.Textbox(label="Status", interactive=False)

      with gr.TabItem("3. Configure & Train"):
-         gr.Markdown("### Set Hyperparameters and Train the RT-DETR Model")
          with gr.Row():
              with gr.Column(scale=1):
-                 model_file_dd = gr.Dropdown(
-                     label="Select Pre-Trained RT-DETR Model",
-                     choices=[m["filename"] for m in RTDETR_MODELS["detection"]],
                      value=DEFAULT_MODEL
                  )
-                 run_name_tb = gr.Textbox(label="Run Name", value="rtdetr_run_1")
                  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
-                 batch_sl = gr.Slider(1, 32, 8, step=1, label="Batch Size")
                  imgsz_num = gr.Number(label="Image Size", value=640)
                  lr_num = gr.Number(label="Learning Rate", value=0.001)
-                 opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="Adam", label="Optimizer")
                  train_btn = gr.Button("Start Training", variant="primary")
              with gr.Column(scale=2):
-                 train_status = gr.Textbox(label="Live Status", interactive=False)
                  loss_plot = gr.Plot(label="Loss Curves")
                  map_plot = gr.Plot(label="mAP Curves")
-                 final_model_file = gr.File(label="Download Trained Model (best.pt)", interactive=False, visible=False)

      with gr.TabItem("4. Upload Model"):
-         gr.Markdown("### Upload Your Trained Model\nAfter training, you can upload the `best.pt` file to Hugging Face and/or GitHub.")
          with gr.Row():
              with gr.Column():
                  gr.Markdown("#### Hugging Face")
                  hf_token = gr.Textbox(label="Hugging Face API Token", type="password")
-                 hf_repo = gr.Textbox(label="Hugging Face Repo ID", placeholder="e.g., username/my-rtdetr-model")
              with gr.Column():
                  gr.Markdown("#### GitHub")
                  gh_token = gr.Textbox(label="GitHub Personal Access Token", type="password")
-                 gh_repo = gr.Textbox(label="GitHub Repo", placeholder="e.g., username/my-rtdetr-repo")
                  upload_btn = gr.Button("Upload Model", variant="primary")
          with gr.Row():
              hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
@@ -743,8 +777,19 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
          outputs=[finalize_status, final_dataset_path_state]
      )
      train_btn.click(
-         fn=training_handler,
-         inputs=[final_dataset_path_state, model_file_dd, run_name_tb, epochs_sl, batch_sl, imgsz_num, lr_num, opt_dd],
          outputs=[train_status, loss_plot, map_plot, final_model_file]
      )
      upload_btn.click(
@@ -754,6 +799,5 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
      )

  if __name__ == "__main__":
-     # Tip: silence Ultralytics settings warning by setting env var:
-     # export YOLO_CONFIG_DIR=/tmp/Ultralytics
      app.launch(debug=True)
 
  import stat
  import yaml
  import gradio as gr
  from roboflow import Roboflow
  import re
  from urllib.parse import urlparse
  import requests
  import json
  from PIL import Image
  import pandas as pd
  import matplotlib.pyplot as plt
  from threading import Thread
  from queue import Queue
  from huggingface_hub import HfApi, HfFolder
  import base64
+ import subprocess
+ import sys
+ import time
+ import glob

  # --- Configuration ---
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

+ # Defaults for the RT-DETRv2 (Supervisely ecosystem) integration
+ RTDETRV2_REPO_URL = "https://github.com/supervisely-ecosystem/RT-DETRv2"
+ DEFAULT_REPO_DIR = os.path.join("third_party", "rtdetrv2")

+ # "Model size" choices are offered only as a hint for which config to use;
+ # the actual command is controlled by the template.
+ RTDETRV2_MODELS = [
+     "rtdetrv2-l-640",  # label only; adapt your command template to the real config/weights
+     "rtdetrv2-x-640"
+ ]
+ DEFAULT_MODEL = RTDETRV2_MODELS[0]

  # ------------------------------
  # Utilities
          pass
      func(path)

  _ROBO_URL_RX = re.compile(
      r"""
      ^(?:
+         (?:https?://)?(?:universe|app|www)?\.?roboflow\.com/
+         (?P<ws>[A-Za-z0-9\-_]+)/
+         (?P<proj>[A-Za-z0-9\-_]+)/?
          (?:
+             (?:dataset/[^/]+/)?
+             (?:v?(?P<ver>\d+))?
          )?
      |
+         (?P<ws2>[A-Za-z0-9\-_]+)/(?P<proj2>[A-Za-z0-9\-_]+)(?:/(?:v)?(?P<ver2>\d+))?
      )$
      """,
      re.VERBOSE | re.IGNORECASE

  def parse_roboflow_url(s: str):
      """
+     Supports:
+     - https://universe.roboflow.com/<workspace>/<project>[/vN]
+     - https://app.roboflow.com/<workspace>/<project>[/vN]
+     - https://roboflow.com/<workspace>/<project>[/vN]
+     - raw: <workspace>/<project>[/vN]
      Returns: (workspace, project, version_or_None)
      """
      s = s.strip()
      m = _ROBO_URL_RX.match(s)
      if m:
          ws = m.group('ws') or m.group('ws2')
          ver = m.group('ver') or m.group('ver2')
          return ws, proj, (int(ver) if ver else None)

      parsed = urlparse(s)
      parts = [p for p in parsed.path.strip('/').split('/') if p]
      if len(parts) >= 2:
          version = None
          if len(parts) >= 3:
              vpart = parts[2]
              if vpart.lower().startswith('v') and vpart[1:].isdigit():
                  version = int(vpart[1:])
              version = int(vpart)
          return parts[0], parts[1], version

      if '/' in s and 'roboflow' not in s:
          p = s.split('/')
          if len(p) >= 2:
              version = None
              if len(p) >= 3:
                  v = p[2]

      return None, None, None

  def get_latest_version(api_key, workspace, project):
      """Gets the latest version number of a Roboflow project."""
      try:
          logging.error(f"Could not get latest version for {workspace}/{project}: {e}")
          return None

+ # --- Normalize class names from data.yaml ---
  def _extract_class_names(data_yaml):
      """
+     Return a list[str] of class names in index order.
+     Supports:
+     - list
+     - dict with numeric keys, e.g. {0: 'cat', 1: 'dog'}
+     - fallback to ['class_0', ...]
      """
      names = data_yaml.get('names', None)
                  return int(x)
              except Exception:
                  return str(x)
+         ordered = sorted(names.keys(), key=_k)
+         names_list = [names[k] for k in ordered]
      elif isinstance(names, list):
          names_list = names
      else:

      return [str(x) for x in names_list]

  def download_dataset(api_key, workspace, project, version):
+     """Download a Roboflow dataset in 'yolov8' layout (works fine for RT-DETR variants)."""
      try:
          rf = Roboflow(api_key=api_key)
          proj = rf.workspace(workspace).project(project)
          with open(data_yaml_path, 'r') as f:
              data_yaml = yaml.safe_load(f)

          class_names = _extract_class_names(data_yaml)
          try:
              nc = int(data_yaml.get('nc', len(class_names)))
          logging.error(f"Failed to download {workspace}/{project}/v{version}: {e}")
          return None, [], [], None

  def label_path_for(img_path: str) -> str:
+     """Convert .../split/images/file.jpg -> .../split/labels/file.txt."""
      split_dir = os.path.dirname(os.path.dirname(img_path))  # .../split
      base = os.path.splitext(os.path.basename(img_path))[0] + '.txt'
      return os.path.join(split_dir, 'labels', base)

  def gather_class_counts(dataset_info, class_mapping):
      """
+     Count, per final class, how many images contain that class at least once (counted once per image).
+     class_mapping: original_name -> final_name (or None if removed).
      """
      if not dataset_info:
          return {}

+     final_names = set(v for v in class_mapping.values() if v is not None)
      counts = {name: 0 for name in final_names}

      for loc, names, splits, _ in dataset_info:
+         id_to_name = {idx: class_mapping.get(n, None) for idx, n in enumerate(names)}

          for split in splits:
              labels_dir = os.path.join(loc, split, 'labels')

      return counts

  def finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress=gr.Progress()):
+     """Merge datasets following the mapping and per-class image limits."""
      merged_dir = 'rolo_merged_dataset'
      if os.path.exists(merged_dir):
          shutil.rmtree(merged_dir, onerror=handle_remove_readonly)
          os.makedirs(os.path.join(merged_dir, split, 'images'), exist_ok=True)
          os.makedirs(os.path.join(merged_dir, split, 'labels'), exist_ok=True)

+     active_classes = sorted(set(cls for cls, limit in class_limits.items() if limit > 0))
      final_class_map = {name: i for i, name in enumerate(active_classes)}

+     # Collect candidates
      all_images = []
      for loc, _, splits, _ in dataset_info:
          for split in splits:
      progress(0.2, desc="Selecting images based on limits...")
      selected_images = []
      current_counts = {cls: 0 for cls in active_classes}
      loc_to_names = {info[0]: info[1] for info in dataset_info}

      for img_path, split, source_loc in progress.tqdm(all_images, desc="Analyzing images"):

          if not image_classes:
              continue
          if any(current_counts[c] >= class_limits[c] for c in image_classes):
              continue

          out_lbl = os.path.join(merged_dir, split, 'labels', os.path.basename(lbl_path))
          shutil.copy(img_path, out_img)

          source_loc = None
          for info in dataset_info:
              if img_path.startswith(info[0]):

  return f"Dataset finalized with {len(selected_images)} images.", os.path.abspath(merged_dir)

+ # ------------------------------
+ # RT-DETRv2 backend helpers
+ # ------------------------------
+
+ def ensure_repo(repo_dir: str, repo_url: str = RTDETRV2_REPO_URL):
+     """Clone the repo into repo_dir if it is not already present."""
+     if os.path.isdir(repo_dir) and os.path.isdir(os.path.join(repo_dir, ".git")):
+         return
+     parent = os.path.dirname(repo_dir)
+     if parent:  # guard: os.makedirs('') raises for a bare directory name
+         os.makedirs(parent, exist_ok=True)
+     logging.info(f"Cloning RT-DETRv2 repo into {repo_dir} ...")
+     cmd = ["git", "clone", "--depth", "1", repo_url, repo_dir]
+     subprocess.run(cmd, check=True)
+
+ def make_train_command(template: str, data_yaml: str, epochs: int, batch: int, imgsz: int,
+                        lr: float, optimizer: str, run_name: str, output_dir: str) -> str:
+     return template.format(
+         data_yaml=data_yaml,
+         epochs=int(epochs),
+         batch=int(batch),
+         imgsz=int(imgsz),
+         lr=float(lr),
+         optimizer=str(optimizer),
+         run_name=str(run_name),
+         output_dir=output_dir
+     )
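With the default template from the UI below, the rendered command would look like this (the tools/train.py flags are placeholders; check the RT-DETRv2 repo for its real entry point and arguments):

      make_train_command(
          template=("python tools/train.py --data {data_yaml} --epochs {epochs} --batch {batch} "
                    "--imgsz {imgsz} --lr {lr} --optimizer {optimizer} --output {output_dir}"),
          data_yaml="rolo_merged_dataset/data.yaml", epochs=100, batch=16, imgsz=640,
          lr=0.001, optimizer="AdamW", run_name="rtdetrv2_run_1", output_dir="runs/train/rtdetrv2_run_1",
      )
      # -> "python tools/train.py --data rolo_merged_dataset/data.yaml --epochs 100 --batch 16
      #     --imgsz 640 --lr 0.001 --optimizer AdamW --output runs/train/rtdetrv2_run_1"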
+
+ _METRIC_PATTERNS = [
+     # add more patterns if your repo prints differently
+     (re.compile(r"mAP@0\.5[:/]?0\.95[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50_95"),
+     (re.compile(r"mAP50[^0-9]*([0-9]*\.?[0-9]+)"), "mAP50"),
+     (re.compile(r"\bval[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "val_loss"),
+     (re.compile(r"\btrain[_/ ]?loss[^0-9\-]*([0-9]*\.?[0-9]+)"), "train_loss"),
+     (re.compile(r"\bepoch[^0-9]*([0-9]+)"), "epoch"),
+ ]
+
+ def parse_metrics_from_line(line: str):
+     result = {}
+     for pat, key in _METRIC_PATTERNS:
+         m = pat.search(line)
+         if m:
+             val = m.group(1)
+             try:
+                 result[key] = int(val) if key == "epoch" else float(val)
+             except Exception:
+                 pass
+     return result
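A quick sanity check with a made-up log line (real training scripts may format metrics differently; extend _METRIC_PATTERNS to match yours):

      parse_metrics_from_line("epoch 12 train_loss 0.831 val_loss 0.912 mAP50 0.601")
      # -> {'mAP50': 0.601, 'val_loss': 0.912, 'train_loss': 0.831, 'epoch': 12}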
+
+ def guess_final_weights(output_dir: str):
+     """
+     Try to locate a 'best' checkpoint in output_dir.
+     Supports .pt/.pth/.pdparams etc. Returns the first match or None.
+     """
+     patterns = [
+         os.path.join(output_dir, "**", "best.*"),
+         os.path.join(output_dir, "**", "best_model.*"),
+         os.path.join(output_dir, "**", "checkpoint_best.*"),
+     ]
+     for p in patterns:
+         hits = glob.glob(p, recursive=True)
+         if hits:
+             return hits[0]
+     return None
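Typical usage after a run (the checkpoint path shown is hypothetical; what actually gets written depends on the training script):

      guess_final_weights("runs/train/rtdetrv2_run_1")
      # e.g. "runs/train/rtdetrv2_run_1/weights/best.pth", or None if nothing matched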

  # ------------------------------
  # Gradio UI Event Handlers
              failures.append((raw, f"DownloadError: {ws}/{proj}/v{ver}"))

      if not dataset_info:
          msg = "No datasets were loaded successfully.\n" + "\n".join([f"- {u}: {why}" for u, why in failures[:10]])
          raise gr.Error(msg)

+     # ensure names are strings before sorting
      all_names = sorted({str(n) for _, names, _, _ in dataset_info for n in names})
      class_map = {name: name for name in all_names}

      initial_counts = gather_class_counts(dataset_info, class_map)
      df_data = [[name, name, initial_counts.get(name, 0), False] for name in all_names]
      status_text = "Datasets loaded successfully."
      if failures:
          status_text += f" ({len(dataset_info)} OK, {len(failures)} failed; see console logs)."

+     # FIX: gr.update(...) (not gr.DataFrame.update)
+     return status_text, dataset_info, gr.update(
          value=pd.DataFrame(df_data, columns=["Original Name", "Rename To", "Max Images", "Remove"])
      )

  def update_class_counts_handler(class_df, dataset_info):
+     """Live preview of merged class counts given the current mapping/removals."""
      if class_df is None or not dataset_info:
          return None

      class_df = pd.DataFrame(class_df)
      mapping = {}
      for _, row in class_df.iterrows():
          else:
              mapping[orig] = row["Rename To"]

      final_names = sorted(set(v for v in mapping.values() if v))
      counts = {k: 0 for k in final_names}

      for loc, names, splits, _ in dataset_info:
+         id_to_final = {idx: mapping.get(n, None) for idx, n in enumerate(names)}

          for split in splits:
              labels_dir = os.path.join(loc, split, 'labels')
      summary_df = pd.DataFrame(list(counts.items()), columns=["Final Class Name", "Est. Total Images"])
      return summary_df

  def finalize_handler(dataset_info, class_df, progress=gr.Progress()):
+     """Create the merged dataset directory with relabeled .txts and data.yaml."""
      if not dataset_info:
          raise gr.Error("Load datasets first in Tab 1.")
      if class_df is None:
          raise gr.Error("Class data is missing.")

      class_df = pd.DataFrame(class_df)
      class_mapping = {}
      class_limits = {}
              continue
          final_name = row["Rename To"]
          class_mapping[orig] = final_name
          class_limits[final_name] = class_limits.get(final_name, 0) + int(row["Max Images"])

      status, path = finalize_merged_dataset(dataset_info, class_mapping, class_limits, progress)
      return status, path

+ def training_handler_rtdetrv2(dataset_path, repo_dir, model_choice, run_name, epochs, batch, imgsz, lr, opt,
+                               cmd_template, progress=gr.Progress()):
+     """
+     Train using the RT-DETRv2 repo via a configurable command template.
+     Streams logs, parses simple metrics when patterns match, and tries to locate a best checkpoint on completion.
+     """
      if not dataset_path:
          raise gr.Error("Finalize a dataset in Tab 2 before training.")

+     # Make sure the repo exists
+     try:
+         ensure_repo(repo_dir)
+     except subprocess.CalledProcessError as e:
+         raise gr.Error(f"Failed to clone RT-DETRv2 repo: {e}")
+
+     # Prepare the output directory
+     output_dir = os.path.join("runs", "train", str(run_name))
+     os.makedirs(output_dir, exist_ok=True)
+
+     data_yaml = os.path.join(dataset_path, "data.yaml")
+     if not os.path.isfile(data_yaml):
+         raise gr.Error(f"'data.yaml' was not found in: {dataset_path}")
+
+     # Build the command
+     cmd = make_train_command(
+         template=cmd_template,
+         data_yaml=data_yaml,
+         epochs=int(epochs),
+         batch=int(batch),
+         imgsz=int(imgsz),
+         lr=float(lr),
+         optimizer=str(opt),
+         run_name=str(run_name),
+         output_dir=output_dir
+     )

+     # Launch the training subprocess in repo_dir
+     logging.info(f"Running training command in {repo_dir}: {cmd}")
+     proc = subprocess.Popen(
+         cmd,
+         cwd=repo_dir,
+         shell=True,
+         stdout=subprocess.PIPE,
+         stderr=subprocess.STDOUT,
+         bufsize=1,
+         universal_newlines=True,
+         env={**os.environ}  # inherit env (CUDA, etc.)
+     )

+     # Live metrics
      history = {k: [] for k in ['epoch', 'train_loss', 'val_loss', 'mAP50', 'mAP50_95']}
+
+     # Stream logs and parse them
+     for line in iter(proc.stdout.readline, ''):
+         line = line.rstrip()
+         # Progress is indeterminate (total epochs can't be read from logs generically);
+         # surface the tail of the current log line as the status.
+         progress(0.0, desc=line[-120:])
+
+         metrics = parse_metrics_from_line(line)
+         if metrics:
+             for k, v in metrics.items():
+                 history[k].append(v)
+
+             # Series can grow at different rates (metrics arrive on different log
+             # lines), so truncate each x/y pair to a common length before plotting.
+             def _xy(xk, yk):
+                 n = min(len(history[xk]), len(history[yk]))
+                 return history[xk][:n], history[yk][:n]
+
+             # Plot Loss
+             fig_loss = plt.figure()
+             ax_loss = fig_loss.add_subplot(111)
+             ax_loss.plot(*_xy('epoch', 'train_loss'), "o-", label='Train Loss')
+             ax_loss.plot(*_xy('epoch', 'val_loss'), "o-", label='Val Loss')
+             ax_loss.legend()
+             ax_loss.set_title("Loss")
+
+             # Plot mAP
+             fig_map = plt.figure()
+             ax_map = fig_map.add_subplot(111)
+             ax_map.plot(*_xy('epoch', 'mAP50'), "o-", label='mAP@0.5')
+             ax_map.plot(*_xy('epoch', 'mAP50_95'), "o-", label='mAP@0.5:0.95')
+             ax_map.legend()
+             ax_map.set_title("mAP")
+
+             # Emit an update to the UI (status text is the last log line)
+             yield line[-200:], fig_loss, fig_map, None
+
+     proc.stdout.close()
+     ret = proc.wait()
+     if ret != 0:
+         raise gr.Error(f"Training process exited with code {ret}. Check console/logs for details.")
+
+     # Try to locate a best checkpoint
+     final_ckpt = guess_final_weights(output_dir)
+     if final_ckpt and os.path.isfile(final_ckpt):
+         yield "Training complete!", None, None, gr.update(value=final_ckpt, visible=True)
+     else:
+         # Still complete, but we couldn't find a checkpoint automatically
+         yield "Training finished. Could not auto-detect a 'best' checkpoint; please check the output directory.", None, None, gr.update(visible=False)

  def upload_handler(model_file, hf_token, hf_repo, gh_token, gh_repo, progress=gr.Progress()):
      """Handles model upload to Hugging Face and GitHub."""
      progress(1)
      return hf_status, gh_status

  # ------------------------------
  # Gradio UI
  # ------------------------------
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="sky")) as app:
+     gr.Markdown("# Rolo: RT-DETRv2 Training Dashboard (Supervisely Ecosystem Backend)")

      # State variables
      dataset_info_state = gr.State([])
          finalize_status = gr.Textbox(label="Status", interactive=False)

      with gr.TabItem("3. Configure & Train"):
+         gr.Markdown("### Set Hyperparameters and Train with RT-DETRv2")
          with gr.Row():
              with gr.Column(scale=1):
+                 model_choice_dd = gr.Dropdown(
+                     label="Model Choice (label only – adjust your command template to use the right config)",
+                     choices=RTDETRV2_MODELS,
                      value=DEFAULT_MODEL
                  )
+                 run_name_tb = gr.Textbox(label="Run Name", value="rtdetrv2_run_1")
                  epochs_sl = gr.Slider(1, 500, 100, step=1, label="Epochs")
+                 batch_sl = gr.Slider(1, 64, 16, step=1, label="Batch Size")
                  imgsz_num = gr.Number(label="Image Size", value=640)
                  lr_num = gr.Number(label="Learning Rate", value=0.001)
+                 opt_dd = gr.Dropdown(["Adam", "AdamW", "SGD"], value="AdamW", label="Optimizer")
+
+                 repo_dir_tb = gr.Textbox(label="RT-DETRv2 repo directory", value=DEFAULT_REPO_DIR)
+                 cmd_template_tb = gr.Textbox(
+                     label="Train command template",
+                     value=(
+                         "python tools/train.py "
+                         "--data {data_yaml} "
+                         "--epochs {epochs} "
+                         "--batch {batch} "
+                         "--imgsz {imgsz} "
+                         "--lr {lr} "
+                         "--optimizer {optimizer} "
+                         "--output {output_dir}"
+                     ),
+                     lines=4
+                 )
                  train_btn = gr.Button("Start Training", variant="primary")
              with gr.Column(scale=2):
+                 train_status = gr.Textbox(label="Live Status / Logs", interactive=False)
                  loss_plot = gr.Plot(label="Loss Curves")
                  map_plot = gr.Plot(label="mAP Curves")
+                 final_model_file = gr.File(label="Download Trained Model (best.*)", interactive=False, visible=False)

      with gr.TabItem("4. Upload Model"):
+         gr.Markdown("### Upload Your Trained Model\nAfter training, you can upload the best checkpoint to Hugging Face and/or GitHub.")
          with gr.Row():
              with gr.Column():
                  gr.Markdown("#### Hugging Face")
                  hf_token = gr.Textbox(label="Hugging Face API Token", type="password")
+                 hf_repo = gr.Textbox(label="Hugging Face Repo ID", placeholder="e.g., username/my-rtdetrv2-model")
              with gr.Column():
                  gr.Markdown("#### GitHub")
                  gh_token = gr.Textbox(label="GitHub Personal Access Token", type="password")
+                 gh_repo = gr.Textbox(label="GitHub Repo", placeholder="e.g., username/my-rtdetrv2-repo")
                  upload_btn = gr.Button("Upload Model", variant="primary")
          with gr.Row():
              hf_status = gr.Textbox(label="Hugging Face Status", interactive=False)
          outputs=[finalize_status, final_dataset_path_state]
      )
      train_btn.click(
+         fn=training_handler_rtdetrv2,
+         inputs=[
+             final_dataset_path_state,  # dataset_path
+             repo_dir_tb,               # repo_dir
+             model_choice_dd,           # model_choice (label only)
+             run_name_tb,
+             epochs_sl,
+             batch_sl,
+             imgsz_num,
+             lr_num,
+             opt_dd,
+             cmd_template_tb
+         ],
          outputs=[train_status, loss_plot, map_plot, final_model_file]
      )
      upload_btn.click(
      )

  if __name__ == "__main__":
+     # If Ultralytics warnings annoy you, set: export YOLO_CONFIG_DIR=/tmp/Ultralytics
      app.launch(debug=True)