Spaces: Running on Zero
Enhance the run_training function in app.py to yield artifacts alongside checkpoints. Update the error-handling paths to yield artifacts together with the logs, and add tracking that exposes the dataset configuration and script files for download. Modify the UI to display the scripts and configuration files, improving user experience and accessibility.
app.py CHANGED
@@ -618,13 +618,14 @@ def run_training(
     # Basic validation
     log_buf = ""
     ckpts: List[str] = []
+    artifacts: List[str] = []
     if not output_name.strip():
         log_buf += "[ERROR] OUTPUT NAME is required.\n"
-        yield (log_buf, ckpts)
+        yield (log_buf, ckpts, artifacts)
         return
     if not caption.strip():
         log_buf += "[ERROR] CAPTION is required.\n"
-        yield (log_buf, ckpts)
+        yield (log_buf, ckpts, artifacts)
         return

     # Ensure /auto holds helper files expected by the script
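For readers less familiar with Gradio, this hunk relies on the pattern where a generator event handler yields one tuple per update, with one element per declared output component; adding `artifacts` as a third element therefore also requires a third output component (see the UI hunks further down). Below is a minimal, self-contained sketch of that pattern with purely illustrative names (`fake_run`, `logs_box`, etc.), not the Space's actual code:

    import gradio as gr
    from typing import List

    def fake_run(name: str):
        # Illustrative generator handler: each yield updates all three outputs.
        logs = ""
        ckpts: List[str] = []
        artifacts: List[str] = []
        for step in range(3):
            logs += f"[demo] step {step} for {name}\n"
            yield (logs, ckpts, artifacts)

    with gr.Blocks() as demo:
        name = gr.Textbox(label="Name")
        btn = gr.Button("Run")
        logs_box = gr.Textbox(label="Logs")
        ckpt_files = gr.Files(label="Checkpoints (live)")
        script_files = gr.Files(label="Scripts & Config (live)")
        # One output component per element of the yielded tuple.
        btn.click(fake_run, inputs=[name], outputs=[logs_box, ckpt_files, script_files])

    if __name__ == "__main__":
        demo.launch()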
@@ -644,11 +645,11 @@ def run_training(
     base_files = _extract_paths(image_uploads)
     if not base_files:
         log_buf += "[ERROR] No images uploaded for IMAGE_FOLDER.\n"
-        yield (log_buf, ckpts)
+        yield (log_buf, ckpts, artifacts)
         return
     base_filenames = _copy_uploads(base_files, img_dir)
     log_buf += f"[QIE] Copied {len(base_filenames)} base images to {img_dir}\n"
-    yield (log_buf, ckpts)
+    yield (log_buf, ckpts, artifacts)

     # Prepare control sets
     control_upload_sets = [
@@ -664,7 +665,7 @@ def run_training(
     # Require control_0; others optional
     if not control_upload_sets[0]:
         log_buf += "[ERROR] control_0 images are required.\n"
-        yield (log_buf, ckpts)
+        yield (log_buf, ckpts, artifacts)
         return

     control_dirs: List[Optional[str]] = []
@@ -679,7 +680,7 @@ def run_training(
         _copy_uploads(uploads, cdir)
         control_dirs.append(folder_name)
         log_buf += f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}\n"
-        yield (log_buf, ckpts)
+        yield (log_buf, ckpts, artifacts)

     # Metadata.jsonl will be generated by create_image_caption_json.py in train_QIE.sh

@@ -705,6 +706,9 @@ def run_training(
         log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n"
     except Exception as e:
         log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n"
+    # Expose dataset config for download (if exists)
+    if os.path.isfile(ds_conf):
+        artifacts = [ds_conf]

     # Resolve models_root and set output_dir_base to the unique dataset dir
     models_root = MODELS_ROOT_RUNTIME
@@ -742,7 +746,19 @@ def run_training(
     log_buf += f"[QIE] Running script: {tmp_script}\n"
     out_dir = os.path.join(out_base, output_name.strip())
     ckpts = _list_checkpoints(out_dir)
-    yield (log_buf, ckpts)
+    # Copy the final script to dataset dir for download
+    used_script_path = os.path.join(out_base, "train_QIE_used.sh")
+    try:
+        shutil.copy2(str(tmp_script), used_script_path)
+        try:
+            os.chmod(used_script_path, 0o755)
+        except Exception:
+            pass
+        if used_script_path not in artifacts:
+            artifacts.append(used_script_path)
+    except Exception:
+        pass
+    yield (log_buf, ckpts, artifacts)

     # Run and stream output
     # Ensure child Python processes are unbuffered for real-time logs
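The "run and stream output" step that follows this hunk is not shown in the diff. As context, here is a minimal sketch of how such a streaming step commonly looks: launch the shell script with unbuffered Python children, append each line to the log buffer, and rescan the output directory every N lines so new checkpoints appear live. Function and variable names are assumptions, not the Space's implementation:

    import os
    import subprocess
    from typing import List

    def stream_script(script_path: str, out_dir: str):
        # Force unbuffered output from child Python processes so logs stream in real time.
        env = dict(os.environ, PYTHONUNBUFFERED="1")
        proc = subprocess.Popen(
            ["bash", script_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            env=env,
        )
        log_buf = ""
        ckpts: List[str] = []
        i = 0
        try:
            for line in proc.stdout:
                log_buf += line
                i += 1
                if i % 30 == 0:  # rescan the output dir periodically
                    if os.path.isdir(out_dir):
                        ckpts = sorted(
                            os.path.join(out_dir, f) for f in os.listdir(out_dir)
                        )
                    yield (log_buf, ckpts)
        finally:
            code = proc.wait()
            log_buf += f"exit code: {code}\n"
        yield (log_buf, ckpts)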
@@ -767,7 +783,11 @@ def run_training(
             i += 1
             if i % 30 == 0:
                 ckpts = _list_checkpoints(out_dir)
-                yield (log_buf, ckpts)
+                # Try to add metadata.jsonl once available
+                metadata_json = os.path.join(out_base, "metadata.jsonl")
+                if os.path.isfile(metadata_json) and metadata_json not in artifacts:
+                    artifacts.append(metadata_json)
+                yield (log_buf, ckpts, artifacts)
     finally:
         code = proc.wait()
         # Try to locate latest LoRA file for download
@@ -778,7 +798,11 @@ def run_training(
             pass
         lora_path = ckpts[0] if ckpts else None
         log_buf += f"[QIE] Exit code: {code}\n"
-        yield (log_buf, ckpts)
+        # Final attempt to include metadata.jsonl
+        metadata_json = os.path.join(out_base, "metadata.jsonl")
+        if os.path.isfile(metadata_json) and metadata_json not in artifacts:
+            artifacts.append(metadata_json)
+        yield (log_buf, ckpts, artifacts)


 def build_ui() -> gr.Blocks:
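The helper `_list_checkpoints` is referenced throughout but not shown in this diff, and `lora_path = ckpts[0] if ckpts else None` only picks the latest LoRA if the list is ordered newest-first. A hypothetical implementation consistent with that usage (an assumption, not the Space's code) could look like:

    import glob
    import os
    from typing import List

    def _list_checkpoints(out_dir: str) -> List[str]:
        """Hypothetical sketch: return .safetensors files newest-first so that
        index 0 is the most recent checkpoint."""
        if not os.path.isdir(out_dir):
            return []
        paths = glob.glob(os.path.join(out_dir, "*.safetensors"))
        return sorted(paths, key=os.path.getmtime, reverse=True)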
@@ -934,6 +958,7 @@ def build_ui() -> gr.Blocks:
                 run_btn = gr.Button("Start Training", variant="primary")
                 logs = gr.Textbox(label="Logs", lines=20)
                 ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
+                scripts_files = gr.Files(label="Scripts & Config (live)", interactive=False)

                 # moved max_epochs/save_every above next to OUTPUT NAME

@@ -964,7 +989,7 @@ def build_ui() -> gr.Blocks:
                     tr_w, tr_h, train_bs, cr_w, cr_h, te_bs,
                     seed_input, max_epochs, save_every,
                 ],
-                outputs=[logs, ckpt_files],
+                outputs=[logs, ckpt_files, scripts_files],
             )

             with gr.TabItem("Prompt Generator"):
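The metadata.jsonl bookkeeping is written twice in this change, once in the streaming loop and once after the exit code is logged. A small helper along the following lines could centralize the check-and-append pattern; this is a hypothetical sketch, not code from the Space:

    import os
    from typing import List

    def add_artifact_if_present(path: str, artifacts: List[str]) -> None:
        """Hypothetical helper: expose a file for download exactly once,
        and only after it actually exists on disk."""
        if os.path.isfile(path) and path not in artifacts:
            artifacts.append(path)

    # Usage at both call sites that currently repeat the check:
    # add_artifact_if_present(os.path.join(out_base, "metadata.jsonl"), artifacts)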