yeq6x committed on
Commit
fc9a363
·
1 Parent(s): 3e32a9e

Add checkpoint listing functionality in app.py to track and display model checkpoints during training. Update run_training to yield checkpoint information and enhance Gradio UI with checkpoint file outputs for improved user experience.

Browse files
Files changed (1) hide show
  1. app.py +45 -22
app.py CHANGED
@@ -7,6 +7,7 @@ import sys
7
  import tempfile
8
  from pathlib import Path
9
  from typing import Dict, Iterable, List, Optional, Any, Tuple
 
10
  import json
11
 
12
  import gradio as gr
@@ -173,6 +174,25 @@ def _copy_uploads(
173
  return used_names
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def _prepare_script(
177
  dataset_name: str,
178
  caption: str,
@@ -445,11 +465,14 @@ def run_training(
445
  ) -> Iterable[tuple]:
446
  # Basic validation
447
  log_buf = ""
 
448
  if not output_name.strip():
449
- yield ("[ERROR] OUTPUT NAME is required.", None)
 
450
  return
451
  if not caption.strip():
452
- yield ("[ERROR] CAPTION is required.", None)
 
453
  return
454
 
455
  # Ensure /auto holds helper files expected by the script
@@ -468,11 +491,12 @@ def run_training(
468
  # Ingest uploads into dataset folders
469
  base_files = _extract_paths(image_uploads)
470
  if not base_files:
471
- yield ("[ERROR] No images uploaded for IMAGE_FOLDER.", None)
 
472
  return
473
  base_filenames = _copy_uploads(base_files, img_dir)
474
  log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
475
- yield (log_buf, None)
476
 
477
  # Prepare control sets
478
  control_upload_sets = [
@@ -487,7 +511,8 @@ def run_training(
487
  ]
488
  # Require control_0; others optional
489
  if not control_upload_sets[0]:
490
- yield ("[ERROR] control_0 images are required.", None)
 
491
  return
492
 
493
  control_dirs: List[Optional[str]] = []
@@ -502,7 +527,7 @@ def run_training(
502
  _copy_uploads(uploads, cdir)
503
  control_dirs.append(folder_name)
504
  log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
505
- yield (log_buf, None)
506
 
507
  # Metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh
508
 
@@ -545,7 +570,9 @@ def run_training(
545
  shell = _pick_shell()
546
  log_buf += f"[QIE] Using shell: {shell}\n"
547
  log_buf += f"[QIE] Running script: {tmp_script}\n"
548
- yield (log_buf, None)
 
 
549
 
550
  # Run and stream output
551
  proc = subprocess.Popen(
@@ -558,29 +585,24 @@ def run_training(
558
  )
559
  try:
560
  assert proc.stdout is not None
 
561
  for line in proc.stdout:
562
  log_buf += line
563
- yield (log_buf, None)
 
 
 
564
  finally:
565
  code = proc.wait()
566
  # Try to locate latest LoRA file for download
567
  lora_path = None
568
  try:
569
- out_dir = os.path.join(out_base, output_name.strip())
570
- if os.path.isdir(out_dir):
571
- cand = []
572
- for root, _, files in os.walk(out_dir):
573
- for fn in files:
574
- if fn.lower().endswith(".safetensors"):
575
- full = os.path.join(root, fn)
576
- cand.append((os.path.getmtime(full), full))
577
- if cand:
578
- cand.sort()
579
- lora_path = cand[-1][1]
580
  except Exception:
581
  pass
 
582
  log_buf += f"[QIE] Exit code: {code}\n"
583
- yield (log_buf, lora_path)
584
 
585
 
586
  def build_ui() -> gr.Blocks:
@@ -685,7 +707,8 @@ def build_ui() -> gr.Blocks:
685
 
686
  run_btn = gr.Button("Start Training", variant="primary")
687
  logs = gr.Textbox(label="Logs", lines=20)
688
- lora_file = gr.File(label="Download LoRA", interactive=False)
 
689
 
690
  with gr.Row():
691
  max_epochs = gr.Number(label="Max epochs (this run)", value=10, precision=0)
@@ -705,7 +728,7 @@ def build_ui() -> gr.Blocks:
705
  ctrl7_files, ctrl7_prefix, ctrl7_suffix,
706
  max_epochs, save_every,
707
  ],
708
- outputs=[logs, lora_file],
709
  )
710
 
711
  return demo
 
7
  import tempfile
8
  from pathlib import Path
9
  from typing import Dict, Iterable, List, Optional, Any, Tuple
10
+ import time
11
  import json
12
 
13
  import gradio as gr
 
174
  return used_names
175
 
176
 
177
+ def _list_checkpoints(out_dir: str, limit: int = 20) -> List[str]:
178
+ try:
179
+ if not out_dir or not os.path.isdir(out_dir):
180
+ return []
181
+ items: List[Tuple[float, str]] = []
182
+ for root, _, files in os.walk(out_dir):
183
+ for fn in files:
184
+ if fn.lower().endswith('.safetensors'):
185
+ full = os.path.join(root, fn)
186
+ try:
187
+ items.append((os.path.getmtime(full), full))
188
+ except Exception:
189
+ pass
190
+ items.sort(reverse=True)
191
+ return [p for _, p in items[:limit]]
192
+ except Exception:
193
+ return []
194
+
195
+
196
  def _prepare_script(
197
  dataset_name: str,
198
  caption: str,
 
465
  ) -> Iterable[tuple]:
466
  # Basic validation
467
  log_buf = ""
468
+ ckpts: List[str] = []
469
  if not output_name.strip():
470
+ log_buf += "[ERROR] OUTPUT NAME is required.\n"
471
+ yield (log_buf, ckpts, None)
472
  return
473
  if not caption.strip():
474
+ log_buf += "[ERROR] CAPTION is required.\n"
475
+ yield (log_buf, ckpts, None)
476
  return
477
 
478
  # Ensure /auto holds helper files expected by the script
 
491
  # Ingest uploads into dataset folders
492
  base_files = _extract_paths(image_uploads)
493
  if not base_files:
494
+ log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n"
495
+ yield (log_buf, ckpts, None)
496
  return
497
  base_filenames = _copy_uploads(base_files, img_dir)
498
  log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
499
+ yield (log_buf, ckpts, None)
500
 
501
  # Prepare control sets
502
  control_upload_sets = [
 
511
  ]
512
  # Require control_0; others optional
513
  if not control_upload_sets[0]:
514
+ log_buf += "[ERROR] control_0 images are required.\n"
515
+ yield (log_buf, ckpts, None)
516
  return
517
 
518
  control_dirs: List[Optional[str]] = []
 
527
  _copy_uploads(uploads, cdir)
528
  control_dirs.append(folder_name)
529
  log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
530
+ yield (log_buf, ckpts, None)
531
 
532
  # Metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh
533
 
 
570
  shell = _pick_shell()
571
  log_buf += f"[QIE] Using shell: {shell}\n"
572
  log_buf += f"[QIE] Running script: {tmp_script}\n"
573
+ out_dir = os.path.join(out_base, output_name.strip())
574
+ ckpts = _list_checkpoints(out_dir)
575
+ yield (log_buf, ckpts, None)
576
 
577
  # Run and stream output
578
  proc = subprocess.Popen(
 
585
  )
586
  try:
587
  assert proc.stdout is not None
588
+ i = 0
589
  for line in proc.stdout:
590
  log_buf += line
591
+ i += 1
592
+ if i % 30 == 0:
593
+ ckpts = _list_checkpoints(out_dir)
594
+ yield (log_buf, ckpts, None)
595
  finally:
596
  code = proc.wait()
597
  # Try to locate latest LoRA file for download
598
  lora_path = None
599
  try:
600
+ ckpts = _list_checkpoints(out_dir)
 
 
 
 
 
 
 
 
 
 
601
  except Exception:
602
  pass
603
+ lora_path = ckpts[0] if ckpts else None
604
  log_buf += f"[QIE] Exit code: {code}\n"
605
+ yield (log_buf, ckpts, lora_path)
606
 
607
 
608
  def build_ui() -> gr.Blocks:
 
707
 
708
  run_btn = gr.Button("Start Training", variant="primary")
709
  logs = gr.Textbox(label="Logs", lines=20)
710
+ ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
711
+ lora_file = gr.File(label="Download LoRA (latest)", interactive=False)
712
 
713
  with gr.Row():
714
  max_epochs = gr.Number(label="Max epochs (this run)", value=10, precision=0)
 
728
  ctrl7_files, ctrl7_prefix, ctrl7_suffix,
729
  max_epochs, save_every,
730
  ],
731
+ outputs=[logs, ckpt_files, lora_file],
732
  )
733
 
734
  return demo