Spaces:
Running
on
Zero
Running
on
Zero
Enhance app.py by adding an override option for run name in the _prepare_script function and updating the run_training function to use an auto-generated dataset directory name. Modify UI to reflect changes from dataset_name to output_name for better clarity in user input.
Browse files
app.py
CHANGED
|
@@ -165,7 +165,8 @@ def _prepare_script(
|
|
| 165 |
dataset_config: Optional[str] = None,
|
| 166 |
override_max_epochs: Optional[int] = None,
|
| 167 |
override_save_every: Optional[int] = None,
|
| 168 |
-
|
|
|
|
| 169 |
"""Create a temporary copy of train_QIE.sh with injected variables.
|
| 170 |
|
| 171 |
Only variables that must vary per-run are replaced. The rest of the script
|
|
@@ -251,6 +252,8 @@ def _prepare_script(
|
|
| 251 |
if override_save_every is not None and override_save_every > 0:
|
| 252 |
txt = re.sub(r"--save_every_n_epochs\s+\d+",
|
| 253 |
f"--save_every_n_epochs {override_save_every}", txt)
|
|
|
|
|
|
|
| 254 |
|
| 255 |
# Write to a temp file alongside this repo for easier inspection
|
| 256 |
run_dir = TRAINING_DIR / ".gradio_runs"
|
|
@@ -332,7 +335,7 @@ def _startup_clone_musubi_tuner() -> None:
|
|
| 332 |
|
| 333 |
@spaces.GPU(duration=7200)
|
| 334 |
def run_training(
|
| 335 |
-
|
| 336 |
caption: str,
|
| 337 |
image_uploads: Any,
|
| 338 |
control0_uploads: Any,
|
|
@@ -350,8 +353,8 @@ def run_training(
|
|
| 350 |
save_every: int,
|
| 351 |
) -> Iterable[str]:
|
| 352 |
# Basic validation
|
| 353 |
-
if not
|
| 354 |
-
yield "[ERROR]
|
| 355 |
return
|
| 356 |
if not caption.strip():
|
| 357 |
yield "[ERROR] CAPTION is required."
|
|
@@ -362,7 +365,9 @@ def run_training(
|
|
| 362 |
# Resolve data root and create dataset directories (auto-decide)
|
| 363 |
global DATA_ROOT_RUNTIME
|
| 364 |
DATA_ROOT_RUNTIME = _ensure_data_root(None)
|
| 365 |
-
|
|
|
|
|
|
|
| 366 |
ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name)
|
| 367 |
img_folder_name = DEFAULT_IMAGE_FOLDER
|
| 368 |
img_dir = os.path.join(ds_dir, img_folder_name)
|
|
@@ -435,6 +440,7 @@ def run_training(
|
|
| 435 |
dataset_config=ds_conf,
|
| 436 |
override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
|
| 437 |
override_save_every=save_every if save_every and save_every > 0 else None,
|
|
|
|
| 438 |
)
|
| 439 |
|
| 440 |
|
|
@@ -470,7 +476,7 @@ def build_ui() -> gr.Blocks:
|
|
| 470 |
""")
|
| 471 |
|
| 472 |
with gr.Row():
|
| 473 |
-
|
| 474 |
caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
|
| 475 |
|
| 476 |
with gr.Row():
|
|
@@ -509,7 +515,7 @@ def build_ui() -> gr.Blocks:
|
|
| 509 |
run_btn.click(
|
| 510 |
fn=run_training,
|
| 511 |
inputs=[
|
| 512 |
-
|
| 513 |
ctrl0_files, ctrl1_files, ctrl2_files, ctrl3_files, ctrl4_files, ctrl5_files, ctrl6_files, ctrl7_files,
|
| 514 |
models_root, output_dir_base, dataset_config,
|
| 515 |
max_epochs, save_every,
|
|
|
|
| 165 |
dataset_config: Optional[str] = None,
|
| 166 |
override_max_epochs: Optional[int] = None,
|
| 167 |
override_save_every: Optional[int] = None,
|
| 168 |
+
override_run_name: Optional[str] = None,
|
| 169 |
+
) -> Path:
|
| 170 |
"""Create a temporary copy of train_QIE.sh with injected variables.
|
| 171 |
|
| 172 |
Only variables that must vary per-run are replaced. The rest of the script
|
|
|
|
| 252 |
if override_save_every is not None and override_save_every > 0:
|
| 253 |
txt = re.sub(r"--save_every_n_epochs\s+\d+",
|
| 254 |
f"--save_every_n_epochs {override_save_every}", txt)
|
| 255 |
+
if override_run_name:
|
| 256 |
+
txt = re.sub(r"^RUN_NAME=.*$", f"RUN_NAME={_bash_quote(override_run_name)}", txt, flags=re.MULTILINE)
|
| 257 |
|
| 258 |
# Write to a temp file alongside this repo for easier inspection
|
| 259 |
run_dir = TRAINING_DIR / ".gradio_runs"
|
|
|
|
| 335 |
|
| 336 |
@spaces.GPU(duration=7200)
|
| 337 |
def run_training(
|
| 338 |
+
output_name: str,
|
| 339 |
caption: str,
|
| 340 |
image_uploads: Any,
|
| 341 |
control0_uploads: Any,
|
|
|
|
| 353 |
save_every: int,
|
| 354 |
) -> Iterable[str]:
|
| 355 |
# Basic validation
|
| 356 |
+
if not output_name.strip():
|
| 357 |
+
yield "[ERROR] OUTPUT NAME is required."
|
| 358 |
return
|
| 359 |
if not caption.strip():
|
| 360 |
yield "[ERROR] CAPTION is required."
|
|
|
|
| 365 |
# Resolve data root and create dataset directories (auto-decide)
|
| 366 |
global DATA_ROOT_RUNTIME
|
| 367 |
DATA_ROOT_RUNTIME = _ensure_data_root(None)
|
| 368 |
+
# Auto-generate dataset directory name
|
| 369 |
+
import time
|
| 370 |
+
ds_name = f"dataset_{int(time.time())}"
|
| 371 |
ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name)
|
| 372 |
img_folder_name = DEFAULT_IMAGE_FOLDER
|
| 373 |
img_dir = os.path.join(ds_dir, img_folder_name)
|
|
|
|
| 440 |
dataset_config=ds_conf,
|
| 441 |
override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
|
| 442 |
override_save_every=save_every if save_every and save_every > 0 else None,
|
| 443 |
+
override_run_name=output_name.strip(),
|
| 444 |
)
|
| 445 |
|
| 446 |
|
|
|
|
| 476 |
""")
|
| 477 |
|
| 478 |
with gr.Row():
|
| 479 |
+
output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
|
| 480 |
caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
|
| 481 |
|
| 482 |
with gr.Row():
|
|
|
|
| 515 |
run_btn.click(
|
| 516 |
fn=run_training,
|
| 517 |
inputs=[
|
| 518 |
+
output_name, caption, images_input,
|
| 519 |
ctrl0_files, ctrl1_files, ctrl2_files, ctrl3_files, ctrl4_files, ctrl5_files, ctrl6_files, ctrl7_files,
|
| 520 |
models_root, output_dir_base, dataset_config,
|
| 521 |
max_epochs, save_every,
|