yeq6x committed on
Commit
325c528
·
1 Parent(s): f440db9

Add dataset configuration update functionality in app.py

Browse files

Implement a new function, _update_dataset_toml, to modify dataset TOML files in-place, allowing updates to resolution and batch size settings. Integrate this function into run_training to ensure dataset configurations are updated based on user inputs. Enhance the UI to include fields for image resolution and control resolution, improving user experience and flexibility in training configurations.

Files changed (1) hide show
  1. app.py +104 -1
app.py CHANGED
@@ -78,6 +78,68 @@ def _ensure_workspace_auto_files() -> None:
78
  pass
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def _ensure_dir_writable(path: str) -> str:
82
  try:
83
  os.makedirs(path, exist_ok=True)
@@ -245,6 +307,7 @@ def _prepare_script(
245
  override_learning_rate: Optional[str] = None,
246
  override_network_dim: Optional[int] = None,
247
  override_seed: Optional[int] = None,
 
248
  ) -> Path:
249
  """Create a temporary copy of train_QIE.sh with injected variables.
250
 
@@ -366,6 +429,15 @@ def _prepare_script(
366
  if override_seed is not None:
367
  txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
368
 
 
 
 
 
 
 
 
 
 
369
  # Prefer overriding variable definitions at top of script (safer than CLI regex)
370
  def _set_var(name: str, value: str) -> None:
371
  nonlocal txt
@@ -533,6 +605,12 @@ def run_training(
533
  ctrl7_suffix: str,
534
  learning_rate: str,
535
  network_dim: int,
 
 
 
 
 
 
536
  seed: int,
537
  max_epochs: int,
538
  save_every: int,
@@ -614,6 +692,20 @@ def run_training(
614
  # Decide dataset_config path with fallback to runtime auto dir
615
  ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  # Resolve models_root and set output_dir_base to the unique dataset dir
618
  models_root = MODELS_ROOT_RUNTIME
619
  out_base = ds_dir
@@ -640,6 +732,7 @@ def run_training(
640
  control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
641
  override_learning_rate=(learning_rate or None),
642
  override_network_dim=int(network_dim) if network_dim is not None else None,
 
643
  override_seed=int(seed) if seed is not None else None,
644
  )
645
 
@@ -730,10 +823,18 @@ def build_ui() -> gr.Blocks:
730
  with gr.Row():
731
  lr_input = gr.Textbox(label="Learning rate", value="1e-3")
732
  dim_input = gr.Number(label="Network dim", value=4, precision=0)
 
733
  seed_input = gr.Number(label="Seed", value=42, precision=0)
734
  max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
735
  save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
736
 
 
 
 
 
 
 
 
737
  with gr.Accordion("Target Image", elem_classes=["pad-section_0"]):
738
  with gr.Group():
739
  with gr.Row():
@@ -859,7 +960,9 @@ def build_ui() -> gr.Blocks:
859
  ctrl5_files, ctrl5_prefix, ctrl5_suffix,
860
  ctrl6_files, ctrl6_prefix, ctrl6_suffix,
861
  ctrl7_files, ctrl7_prefix, ctrl7_suffix,
862
- lr_input, dim_input, seed_input, max_epochs, save_every,
 
 
863
  ],
864
  outputs=[logs, ckpt_files],
865
  )
 
78
  pass
79
 
80
 
81
+ def _update_dataset_toml(
82
+ path: str,
83
+ *,
84
+ img_res_w: Optional[int] = None,
85
+ img_res_h: Optional[int] = None,
86
+ train_batch_size: Optional[int] = None,
87
+ control_res_w: Optional[int] = None,
88
+ control_res_h: Optional[int] = None,
89
+ ) -> None:
90
+ """Update dataset TOML for resolution/batch/control resolution in-place.
91
+
92
+ - Updates [general] resolution and batch_size if provided.
93
+ - Updates first [[datasets]] qwen_image_edit_control_resolution if provided.
94
+ - Creates sections/keys if missing.
95
+ """
96
+ try:
97
+ txt = Path(path).read_text(encoding="utf-8")
98
+ except Exception:
99
+ return
100
+
101
+ def _set_in_general(block: str, key: str, value_line: str) -> str:
102
+ import re as _re
103
+ if _re.search(rf"(?m)^\s*{_re.escape(key)}\s*=", block):
104
+ block = _re.sub(rf"(?m)^\s*{_re.escape(key)}\s*=.*$", value_line, block)
105
+ else:
106
+ block = block.rstrip() + "\n" + value_line + "\n"
107
+ return block
108
+
109
+ import re
110
+ m = re.search(r"(?ms)^\[general\]\s*(.*?)(?=^\[|\Z)", txt)
111
+ if not m:
112
+ gen = "[general]\n"
113
+ if img_res_w and img_res_h:
114
+ gen += f"resolution = [{int(img_res_w)}, {int(img_res_h)}]\n"
115
+ if train_batch_size is not None:
116
+ gen += f"batch_size = {int(train_batch_size)}\n"
117
+ txt = gen + "\n" + txt
118
+ else:
119
+ head, block, tail = txt[:m.start(1)], m.group(1), txt[m.end(1):]
120
+ if img_res_w and img_res_h:
121
+ block = _set_in_general(block, "resolution", f"resolution = [{int(img_res_w)}, {int(img_res_h)}]")
122
+ if train_batch_size is not None:
123
+ block = _set_in_general(block, "batch_size", f"batch_size = {int(train_batch_size)}")
124
+ txt = head + block + tail
125
+
126
+ if control_res_w and control_res_h:
127
+ m2 = re.search(r"(?ms)^\[\[datasets\]\]\s*(.*?)(?=^\[\[|\Z)", txt)
128
+ if m2:
129
+ head, block, tail = txt[:m2.start(1)], m2.group(1), txt[m2.end(1):]
130
+ line = f"qwen_image_edit_control_resolution = [{int(control_res_w)}, {int(control_res_h)}]"
131
+ if re.search(r"(?m)^\s*qwen_image_edit_control_resolution\s*=", block):
132
+ block = re.sub(r"(?m)^\s*qwen_image_edit_control_resolution\s*=.*$", line, block)
133
+ else:
134
+ block = block.rstrip() + "\n" + line + "\n"
135
+ txt = head + block + tail
136
+
137
+ try:
138
+ Path(path).write_text(txt, encoding="utf-8")
139
+ except Exception:
140
+ pass
141
+
142
+
143
  def _ensure_dir_writable(path: str) -> str:
144
  try:
145
  os.makedirs(path, exist_ok=True)
 
307
  override_learning_rate: Optional[str] = None,
308
  override_network_dim: Optional[int] = None,
309
  override_seed: Optional[int] = None,
310
+ override_te_cache_bs: Optional[int] = None,
311
  ) -> Path:
312
  """Create a temporary copy of train_QIE.sh with injected variables.
313
 
 
429
  if override_seed is not None:
430
  txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
431
 
432
+ # Optionally override text-encoder cache batch size
433
+ if override_te_cache_bs is not None and override_te_cache_bs > 0:
434
+ txt = re.sub(
435
+ r"(qwen_image_cache_text_encoder_outputs\.py[^\n]*--batch_size\s+)\d+",
436
+ rf"\g<1>{int(override_te_cache_bs)}",
437
+ txt,
438
+ flags=re.MULTILINE,
439
+ )
440
+
441
  # Prefer overriding variable definitions at top of script (safer than CLI regex)
442
  def _set_var(name: str, value: str) -> None:
443
  nonlocal txt
 
605
  ctrl7_suffix: str,
606
  learning_rate: str,
607
  network_dim: int,
608
+ train_res_w: int,
609
+ train_res_h: int,
610
+ train_batch_size: int,
611
+ control_res_w: int,
612
+ control_res_h: int,
613
+ te_cache_batch_size: int,
614
  seed: int,
615
  max_epochs: int,
616
  save_every: int,
 
692
  # Decide dataset_config path with fallback to runtime auto dir
693
  ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
694
 
695
+ # Update dataset config with requested resolution/batch settings
696
+ try:
697
+ _update_dataset_toml(
698
+ ds_conf,
699
+ img_res_w=int(train_res_w) if train_res_w else None,
700
+ img_res_h=int(train_res_h) if train_res_h else None,
701
+ train_batch_size=int(train_batch_size) if train_batch_size else None,
702
+ control_res_w=int(control_res_w) if control_res_w else None,
703
+ control_res_h=int(control_res_h) if control_res_h else None,
704
+ )
705
+ log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n"
706
+ except Exception as e:
707
+ log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n"
708
+
709
  # Resolve models_root and set output_dir_base to the unique dataset dir
710
  models_root = MODELS_ROOT_RUNTIME
711
  out_base = ds_dir
 
732
  control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
733
  override_learning_rate=(learning_rate or None),
734
  override_network_dim=int(network_dim) if network_dim is not None else None,
735
+ override_te_cache_bs=int(te_cache_batch_size) if te_cache_batch_size else None,
736
  override_seed=int(seed) if seed is not None else None,
737
  )
738
 
 
823
  with gr.Row():
824
  lr_input = gr.Textbox(label="Learning rate", value="1e-3")
825
  dim_input = gr.Number(label="Network dim", value=4, precision=0)
826
+ train_bs = gr.Number(label="Batch size (dataset)", value=1, precision=0)
827
  seed_input = gr.Number(label="Seed", value=42, precision=0)
828
  max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
829
  save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
830
 
831
+ with gr.Row():
832
+ tr_w = gr.Number(label="Image resolution W", value=1024, precision=0)
833
+ tr_h = gr.Number(label="Image resolution H", value=1024, precision=0)
834
+ cr_w = gr.Number(label="Control resolution W", value=1024, precision=0)
835
+ cr_h = gr.Number(label="Control resolution H", value=1024, precision=0)
836
+ te_bs = gr.Number(label="TE cache batch size", value=16, precision=0)
837
+
838
  with gr.Accordion("Target Image", elem_classes=["pad-section_0"]):
839
  with gr.Group():
840
  with gr.Row():
 
960
  ctrl5_files, ctrl5_prefix, ctrl5_suffix,
961
  ctrl6_files, ctrl6_prefix, ctrl6_suffix,
962
  ctrl7_files, ctrl7_prefix, ctrl7_suffix,
963
+ lr_input, dim_input,
964
+ tr_w, tr_h, train_bs, cr_w, cr_h, te_bs,
965
+ seed_input, max_epochs, save_every,
966
  ],
967
  outputs=[logs, ckpt_files],
968
  )