yeq6x commited on
Commit
f03ecf2
·
1 Parent(s): 7c1bc29

Enhance app.py by adding an override option for run name in the _prepare_script function and updating the run_training function to use an auto-generated dataset directory name. Modify UI to reflect changes from dataset_name to output_name for better clarity in user input.

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -165,7 +165,8 @@ def _prepare_script(
165
  dataset_config: Optional[str] = None,
166
  override_max_epochs: Optional[int] = None,
167
  override_save_every: Optional[int] = None,
168
- ) -> Path:
 
169
  """Create a temporary copy of train_QIE.sh with injected variables.
170
 
171
  Only variables that must vary per-run are replaced. The rest of the script
@@ -251,6 +252,8 @@ def _prepare_script(
251
  if override_save_every is not None and override_save_every > 0:
252
  txt = re.sub(r"--save_every_n_epochs\s+\d+",
253
  f"--save_every_n_epochs {override_save_every}", txt)
 
 
254
 
255
  # Write to a temp file alongside this repo for easier inspection
256
  run_dir = TRAINING_DIR / ".gradio_runs"
@@ -332,7 +335,7 @@ def _startup_clone_musubi_tuner() -> None:
332
 
333
  @spaces.GPU(duration=7200)
334
  def run_training(
335
- dataset_name: str,
336
  caption: str,
337
  image_uploads: Any,
338
  control0_uploads: Any,
@@ -350,8 +353,8 @@ def run_training(
350
  save_every: int,
351
  ) -> Iterable[str]:
352
  # Basic validation
353
- if not dataset_name.strip():
354
- yield "[ERROR] DATASET_NAME is required."
355
  return
356
  if not caption.strip():
357
  yield "[ERROR] CAPTION is required."
@@ -362,7 +365,9 @@ def run_training(
362
  # Resolve data root and create dataset directories (auto-decide)
363
  global DATA_ROOT_RUNTIME
364
  DATA_ROOT_RUNTIME = _ensure_data_root(None)
365
- ds_name = dataset_name.strip()
 
 
366
  ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name)
367
  img_folder_name = DEFAULT_IMAGE_FOLDER
368
  img_dir = os.path.join(ds_dir, img_folder_name)
@@ -435,6 +440,7 @@ def run_training(
435
  dataset_config=ds_conf,
436
  override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
437
  override_save_every=save_every if save_every and save_every > 0 else None,
 
438
  )
439
 
440
 
@@ -470,7 +476,7 @@ def build_ui() -> gr.Blocks:
470
  """)
471
 
472
  with gr.Row():
473
- dataset_name = gr.Textbox(label="DATASET_NAME (folder under DATA_ROOT)", placeholder="my_dataset", lines=1)
474
  caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
475
 
476
  with gr.Row():
@@ -509,7 +515,7 @@ def build_ui() -> gr.Blocks:
509
  run_btn.click(
510
  fn=run_training,
511
  inputs=[
512
- dataset_name, caption, images_input,
513
  ctrl0_files, ctrl1_files, ctrl2_files, ctrl3_files, ctrl4_files, ctrl5_files, ctrl6_files, ctrl7_files,
514
  models_root, output_dir_base, dataset_config,
515
  max_epochs, save_every,
 
165
  dataset_config: Optional[str] = None,
166
  override_max_epochs: Optional[int] = None,
167
  override_save_every: Optional[int] = None,
168
+ override_run_name: Optional[str] = None,
169
+ ) -> Path:
170
  """Create a temporary copy of train_QIE.sh with injected variables.
171
 
172
  Only variables that must vary per-run are replaced. The rest of the script
 
252
  if override_save_every is not None and override_save_every > 0:
253
  txt = re.sub(r"--save_every_n_epochs\s+\d+",
254
  f"--save_every_n_epochs {override_save_every}", txt)
255
+ if override_run_name:
256
+ txt = re.sub(r"^RUN_NAME=.*$", f"RUN_NAME={_bash_quote(override_run_name)}", txt, flags=re.MULTILINE)
257
 
258
  # Write to a temp file alongside this repo for easier inspection
259
  run_dir = TRAINING_DIR / ".gradio_runs"
 
335
 
336
  @spaces.GPU(duration=7200)
337
  def run_training(
338
+ output_name: str,
339
  caption: str,
340
  image_uploads: Any,
341
  control0_uploads: Any,
 
353
  save_every: int,
354
  ) -> Iterable[str]:
355
  # Basic validation
356
+ if not output_name.strip():
357
+ yield "[ERROR] OUTPUT NAME is required."
358
  return
359
  if not caption.strip():
360
  yield "[ERROR] CAPTION is required."
 
365
  # Resolve data root and create dataset directories (auto-decide)
366
  global DATA_ROOT_RUNTIME
367
  DATA_ROOT_RUNTIME = _ensure_data_root(None)
368
+ # Auto-generate dataset directory name
369
+ import time
370
+ ds_name = f"dataset_{int(time.time())}"
371
  ds_dir = os.path.join(DATA_ROOT_RUNTIME, ds_name)
372
  img_folder_name = DEFAULT_IMAGE_FOLDER
373
  img_dir = os.path.join(ds_dir, img_folder_name)
 
440
  dataset_config=ds_conf,
441
  override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
442
  override_save_every=save_every if save_every and save_every > 0 else None,
443
+ override_run_name=output_name.strip(),
444
  )
445
 
446
 
 
476
  """)
477
 
478
  with gr.Row():
479
+ output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
480
  caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
481
 
482
  with gr.Row():
 
515
  run_btn.click(
516
  fn=run_training,
517
  inputs=[
518
+ output_name, caption, images_input,
519
  ctrl0_files, ctrl1_files, ctrl2_files, ctrl3_files, ctrl4_files, ctrl5_files, ctrl6_files, ctrl7_files,
520
  models_root, output_dir_base, dataset_config,
521
  max_epochs, save_every,