yeq6x commited on
Commit
4cf412e
·
1 Parent(s): 325c528

Enhance run_training function in app.py to yield artifacts alongside checkpoints. Update error handling to include artifacts in log outputs, and add functionality to track and expose dataset configuration and script files for download. Modify UI to display scripts and configuration files, improving user experience and accessibility.

Browse files
Files changed (1) hide show
  1. app.py +35 -10
app.py CHANGED
@@ -618,13 +618,14 @@ def run_training(
618
  # Basic validation
619
  log_buf = ""
620
  ckpts: List[str] = []
 
621
  if not output_name.strip():
622
  log_buf += "[ERROR] OUTPUT NAME is required.\n"
623
- yield (log_buf, ckpts)
624
  return
625
  if not caption.strip():
626
  log_buf += "[ERROR] CAPTION is required.\n"
627
- yield (log_buf, ckpts)
628
  return
629
 
630
  # Ensure /auto holds helper files expected by the script
@@ -644,11 +645,11 @@ def run_training(
644
  base_files = _extract_paths(image_uploads)
645
  if not base_files:
646
  log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n"
647
- yield (log_buf, ckpts)
648
  return
649
  base_filenames = _copy_uploads(base_files, img_dir)
650
  log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
651
- yield (log_buf, ckpts)
652
 
653
  # Prepare control sets
654
  control_upload_sets = [
@@ -664,7 +665,7 @@ def run_training(
664
  # Require control_0; others optional
665
  if not control_upload_sets[0]:
666
  log_buf += "[ERROR] control_0 images are required.\n"
667
- yield (log_buf, ckpts)
668
  return
669
 
670
  control_dirs: List[Optional[str]] = []
@@ -679,7 +680,7 @@ def run_training(
679
  _copy_uploads(uploads, cdir)
680
  control_dirs.append(folder_name)
681
  log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
682
- yield (log_buf, ckpts)
683
 
684
  # Metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh
685
 
@@ -705,6 +706,9 @@ def run_training(
705
  log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n"
706
  except Exception as e:
707
  log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n"
 
 
 
708
 
709
  # Resolve models_root and set output_dir_base to the unique dataset dir
710
  models_root = MODELS_ROOT_RUNTIME
@@ -742,7 +746,19 @@ def run_training(
742
  log_buf += f"[QIE] Running script: {tmp_script}\n"
743
  out_dir = os.path.join(out_base, output_name.strip())
744
  ckpts = _list_checkpoints(out_dir)
745
- yield (log_buf, ckpts)
 
 
 
 
 
 
 
 
 
 
 
 
746
 
747
  # Run and stream output
748
  # Ensure child Python processes are unbuffered for real-time logs
@@ -767,7 +783,11 @@ def run_training(
767
  i += 1
768
  if i % 30 == 0:
769
  ckpts = _list_checkpoints(out_dir)
770
- yield (log_buf, ckpts)
 
 
 
 
771
  finally:
772
  code = proc.wait()
773
  # Try to locate latest LoRA file for download
@@ -778,7 +798,11 @@ def run_training(
778
  pass
779
  lora_path = ckpts[0] if ckpts else None
780
  log_buf += f"[QIE] Exit code: {code}\n"
781
- yield (log_buf, ckpts)
 
 
 
 
782
 
783
 
784
  def build_ui() -> gr.Blocks:
@@ -934,6 +958,7 @@ def build_ui() -> gr.Blocks:
934
  run_btn = gr.Button("Start Training", variant="primary")
935
  logs = gr.Textbox(label="Logs", lines=20)
936
  ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
 
937
 
938
  # moved max_epochs/save_every above next to OUTPUT NAME
939
 
@@ -964,7 +989,7 @@ def build_ui() -> gr.Blocks:
964
  tr_w, tr_h, train_bs, cr_w, cr_h, te_bs,
965
  seed_input, max_epochs, save_every,
966
  ],
967
- outputs=[logs, ckpt_files],
968
  )
969
 
970
  with gr.TabItem("Prompt Generator"):
 
618
  # Basic validation
619
  log_buf = ""
620
  ckpts: List[str] = []
621
+ artifacts: List[str] = []
622
  if not output_name.strip():
623
  log_buf += "[ERROR] OUTPUT NAME is required.\n"
624
+ yield (log_buf, ckpts, artifacts)
625
  return
626
  if not caption.strip():
627
  log_buf += "[ERROR] CAPTION is required.\n"
628
+ yield (log_buf, ckpts, artifacts)
629
  return
630
 
631
  # Ensure /auto holds helper files expected by the script
 
645
  base_files = _extract_paths(image_uploads)
646
  if not base_files:
647
  log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n"
648
+ yield (log_buf, ckpts, artifacts)
649
  return
650
  base_filenames = _copy_uploads(base_files, img_dir)
651
  log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
652
+ yield (log_buf, ckpts, artifacts)
653
 
654
  # Prepare control sets
655
  control_upload_sets = [
 
665
  # Require control_0; others optional
666
  if not control_upload_sets[0]:
667
  log_buf += "[ERROR] control_0 images are required.\n"
668
+ yield (log_buf, ckpts, artifacts)
669
  return
670
 
671
  control_dirs: List[Optional[str]] = []
 
680
  _copy_uploads(uploads, cdir)
681
  control_dirs.append(folder_name)
682
  log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
683
+ yield (log_buf, ckpts, artifacts)
684
 
685
  # Metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh
686
 
 
706
  log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n"
707
  except Exception as e:
708
  log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n"
709
+ # Expose dataset config for download (if exists)
710
+ if os.path.isfile(ds_conf):
711
+ artifacts = [ds_conf]
712
 
713
  # Resolve models_root and set output_dir_base to the unique dataset dir
714
  models_root = MODELS_ROOT_RUNTIME
 
746
  log_buf += f"[QIE] Running script: {tmp_script}\n"
747
  out_dir = os.path.join(out_base, output_name.strip())
748
  ckpts = _list_checkpoints(out_dir)
749
+ # Copy the final script to dataset dir for download
750
+ used_script_path = os.path.join(out_base, "train_QIE_used.sh")
751
+ try:
752
+ shutil.copy2(str(tmp_script), used_script_path)
753
+ try:
754
+ os.chmod(used_script_path, 0o755)
755
+ except Exception:
756
+ pass
757
+ if used_script_path not in artifacts:
758
+ artifacts.append(used_script_path)
759
+ except Exception:
760
+ pass
761
+ yield (log_buf, ckpts, artifacts)
762
 
763
  # Run and stream output
764
  # Ensure child Python processes are unbuffered for real-time logs
 
783
  i += 1
784
  if i % 30 == 0:
785
  ckpts = _list_checkpoints(out_dir)
786
+ # Try to add metadata.jsonl once available
787
+ metadata_json = os.path.join(out_base, "metadata.jsonl")
788
+ if os.path.isfile(metadata_json) and metadata_json not in artifacts:
789
+ artifacts.append(metadata_json)
790
+ yield (log_buf, ckpts, artifacts)
791
  finally:
792
  code = proc.wait()
793
  # Try to locate latest LoRA file for download
 
798
  pass
799
  lora_path = ckpts[0] if ckpts else None
800
  log_buf += f"[QIE] Exit code: {code}\n"
801
+ # Final attempt to include metadata.jsonl
802
+ metadata_json = os.path.join(out_base, "metadata.jsonl")
803
+ if os.path.isfile(metadata_json) and metadata_json not in artifacts:
804
+ artifacts.append(metadata_json)
805
+ yield (log_buf, ckpts, artifacts)
806
 
807
 
808
  def build_ui() -> gr.Blocks:
 
958
  run_btn = gr.Button("Start Training", variant="primary")
959
  logs = gr.Textbox(label="Logs", lines=20)
960
  ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
961
+ scripts_files = gr.Files(label="Scripts & Config (live)", interactive=False)
962
 
963
  # moved max_epochs/save_every above next to OUTPUT NAME
964
 
 
989
  tr_w, tr_h, train_bs, cr_w, cr_h, te_bs,
990
  seed_input, max_epochs, save_every,
991
  ],
992
+ outputs=[logs, ckpt_files, scripts_files],
993
  )
994
 
995
  with gr.TabItem("Prompt Generator"):