yeq6x commited on
Commit
59635a0
·
1 Parent(s): f03ecf2

Refactor run_training function in app.py to return tuples for error and log messages, enhancing clarity in output handling. Update UI to remove user input for models_root, output_dir_base, and dataset_config, which are now resolved at runtime. Implement logic to locate the latest LoRA file for download after training completion.

Browse files
Files changed (1) hide show
  1. app.py +42 -26
app.py CHANGED
@@ -346,18 +346,15 @@ def run_training(
346
  control5_uploads: Any,
347
  control6_uploads: Any,
348
  control7_uploads: Any,
349
- models_root: str,
350
- output_dir_base: str,
351
- dataset_config: str,
352
  max_epochs: int,
353
  save_every: int,
354
- ) -> Iterable[str]:
355
  # Basic validation
356
  if not output_name.strip():
357
- yield "[ERROR] OUTPUT NAME is required."
358
  return
359
  if not caption.strip():
360
- yield "[ERROR] CAPTION is required."
361
  return
362
 
363
  # Ensure /auto holds helper files expected by the script
@@ -376,10 +373,10 @@ def run_training(
376
  # Ingest uploads into dataset folders
377
  base_files = _extract_paths(image_uploads)
378
  if not base_files:
379
- yield "[ERROR] No images uploaded for IMAGE_FOLDER."
380
  return
381
  base_filenames = _copy_uploads(base_files, img_dir)
382
- yield f"[QIE] Copied {len(base_filenames)} base images to {img_dir}"
383
 
384
  # Prepare control sets
385
  control_upload_sets = [
@@ -408,14 +405,14 @@ def run_training(
408
  replicated = uploads * len(base_filenames)
409
  _copy_uploads(replicated, cdir, rename_to=base_filenames)
410
  else:
411
- yield f"[ERROR] control_{i}: file count {len(uploads)} must be 1 or {len(base_filenames)}."
412
  return
413
  control_dirs.append(folder_name)
414
  any_control = True
415
- yield f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}"
416
 
417
  if not any_control:
418
- yield "[ERROR] At least one control folder is required for edit-plus training."
419
  return
420
 
421
  # Prepare script with user parameters
@@ -425,9 +422,15 @@ def run_training(
425
  ]
426
 
427
  # Decide dataset_config path with fallback to runtime auto dir
428
- ds_conf_input = (dataset_config or "").strip()
429
- ds_conf_runtime = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
430
- ds_conf = ds_conf_input if (ds_conf_input and os.path.exists(ds_conf_input)) else ds_conf_runtime
 
 
 
 
 
 
431
 
432
  tmp_script = _prepare_script(
433
  dataset_name=ds_name,
@@ -435,8 +438,8 @@ def run_training(
435
  data_root=DATA_ROOT_RUNTIME,
436
  image_folder=img_folder_name,
437
  control_folders=control_folders,
438
- models_root=models_root.strip() or MODELS_ROOT_RUNTIME,
439
- output_dir_base=(output_dir_base.strip() or None),
440
  dataset_config=ds_conf,
441
  override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
442
  override_save_every=save_every if save_every and save_every > 0 else None,
@@ -445,8 +448,8 @@ def run_training(
445
 
446
 
447
  shell = _pick_shell()
448
- yield f"[QIE] Using shell: {shell}"
449
- yield f"[QIE] Running script: {tmp_script}"
450
 
451
  # Run and stream output
452
  proc = subprocess.Popen(
@@ -460,10 +463,26 @@ def run_training(
460
  try:
461
  assert proc.stdout is not None
462
  for line in proc.stdout:
463
- yield line.rstrip("\n")
464
  finally:
465
  code = proc.wait()
466
- yield f"[QIE] Exit code: {code}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
 
469
  def build_ui() -> gr.Blocks:
@@ -500,13 +519,11 @@ def build_ui() -> gr.Blocks:
500
  with gr.Row():
501
  ctrl7_files = gr.File(label="Upload control_7 images", file_count="multiple", type="filepath")
502
 
503
- with gr.Row():
504
- models_root = gr.Textbox(label="Models root", value=MODELS_ROOT_RUNTIME)
505
- output_dir_base = gr.Textbox(label="OUTPUT_DIR_BASE", value=DEFAULT_OUTPUT_DIR_BASE)
506
- dataset_config = gr.Textbox(label="DATASET_CONFIG", value=str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml"))
507
 
508
  run_btn = gr.Button("Start Training", variant="primary")
509
  logs = gr.Textbox(label="Logs", lines=20)
 
510
 
511
  with gr.Row():
512
  max_epochs = gr.Number(label="Max epochs (this run)", value=10, precision=0)
@@ -517,10 +534,9 @@ def build_ui() -> gr.Blocks:
517
  inputs=[
518
  output_name, caption, images_input,
519
  ctrl0_files, ctrl1_files, ctrl2_files, ctrl3_files, ctrl4_files, ctrl5_files, ctrl6_files, ctrl7_files,
520
- models_root, output_dir_base, dataset_config,
521
  max_epochs, save_every,
522
  ],
523
- outputs=logs,
524
  )
525
 
526
  return demo
 
346
  control5_uploads: Any,
347
  control6_uploads: Any,
348
  control7_uploads: Any,
 
 
 
349
  max_epochs: int,
350
  save_every: int,
351
+ ) -> Iterable[tuple]:
352
  # Basic validation
353
  if not output_name.strip():
354
+ yield ("[ERROR] OUTPUT NAME is required.", None)
355
  return
356
  if not caption.strip():
357
+ yield ("[ERROR] CAPTION is required.", None)
358
  return
359
 
360
  # Ensure /auto holds helper files expected by the script
 
373
  # Ingest uploads into dataset folders
374
  base_files = _extract_paths(image_uploads)
375
  if not base_files:
376
+ yield ("[ERROR] No images uploaded for IMAGE_FOLDER.", None)
377
  return
378
  base_filenames = _copy_uploads(base_files, img_dir)
379
+ yield (f"[QIE] Copied {len(base_filenames)} base images to {img_dir}", None)
380
 
381
  # Prepare control sets
382
  control_upload_sets = [
 
405
  replicated = uploads * len(base_filenames)
406
  _copy_uploads(replicated, cdir, rename_to=base_filenames)
407
  else:
408
+ yield (f"[ERROR] control_{i}: file count {len(uploads)} must be 1 or {len(base_filenames)}.", None)
409
  return
410
  control_dirs.append(folder_name)
411
  any_control = True
412
+ yield (f"[QIE] Copied {len(uploads)} control_{i} images to {cdir}", None)
413
 
414
  if not any_control:
415
+ yield ("[ERROR] At least one control folder is required for edit-plus training.", None)
416
  return
417
 
418
  # Prepare script with user parameters
 
422
  ]
423
 
424
  # Decide dataset_config path with fallback to runtime auto dir
425
+ ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
426
+
427
+ # Resolve models_root and output_dir_base at runtime
428
+ models_root = MODELS_ROOT_RUNTIME
429
+ out_base = os.path.join(AUTO_DIR_RUNTIME, "train_LoRA")
430
+ try:
431
+ os.makedirs(out_base, exist_ok=True)
432
+ except Exception:
433
+ pass
434
 
435
  tmp_script = _prepare_script(
436
  dataset_name=ds_name,
 
438
  data_root=DATA_ROOT_RUNTIME,
439
  image_folder=img_folder_name,
440
  control_folders=control_folders,
441
+ models_root=models_root,
442
+ output_dir_base=out_base,
443
  dataset_config=ds_conf,
444
  override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
445
  override_save_every=save_every if save_every and save_every > 0 else None,
 
448
 
449
 
450
  shell = _pick_shell()
451
+ yield (f"[QIE] Using shell: {shell}", None)
452
+ yield (f"[QIE] Running script: {tmp_script}", None)
453
 
454
  # Run and stream output
455
  proc = subprocess.Popen(
 
463
  try:
464
  assert proc.stdout is not None
465
  for line in proc.stdout:
466
+ yield (line.rstrip("\n"), None)
467
  finally:
468
  code = proc.wait()
469
+ # Try to locate latest LoRA file for download
470
+ lora_path = None
471
+ try:
472
+ out_dir = os.path.join(out_base, output_name.strip())
473
+ if os.path.isdir(out_dir):
474
+ cand = []
475
+ for root, _, files in os.walk(out_dir):
476
+ for fn in files:
477
+ if fn.lower().endswith(".safetensors"):
478
+ full = os.path.join(root, fn)
479
+ cand.append((os.path.getmtime(full), full))
480
+ if cand:
481
+ cand.sort()
482
+ lora_path = cand[-1][1]
483
+ except Exception:
484
+ pass
485
+ yield (f"[QIE] Exit code: {code}", lora_path)
486
 
487
 
488
  def build_ui() -> gr.Blocks:
 
519
  with gr.Row():
520
  ctrl7_files = gr.File(label="Upload control_7 images", file_count="multiple", type="filepath")
521
 
522
+ # Models root / OUTPUT_DIR_BASE / DATASET_CONFIG are auto-resolved at runtime; no user input needed.
 
 
 
523
 
524
  run_btn = gr.Button("Start Training", variant="primary")
525
  logs = gr.Textbox(label="Logs", lines=20)
526
+ lora_file = gr.File(label="Download LoRA", interactive=False)
527
 
528
  with gr.Row():
529
  max_epochs = gr.Number(label="Max epochs (this run)", value=10, precision=0)
 
534
  inputs=[
535
  output_name, caption, images_input,
536
  ctrl0_files, ctrl1_files, ctrl2_files, ctrl3_files, ctrl4_files, ctrl5_files, ctrl6_files, ctrl7_files,
 
537
  max_epochs, save_every,
538
  ],
539
+ outputs=[logs, lora_file],
540
  )
541
 
542
  return demo