Spaces:
Running
on
Zero
Running
on
Zero
Add dataset configuration update functionality in app.py
Browse filesImplement a new function, _update_dataset_toml, to modify dataset TOML files in-place, allowing updates to resolution and batch size settings. Integrate this function into run_training to ensure dataset configurations are updated based on user inputs. Enhance the UI to include fields for image resolution and control resolution, improving user experience and flexibility in training configurations.
app.py
CHANGED
|
@@ -78,6 +78,68 @@ def _ensure_workspace_auto_files() -> None:
|
|
| 78 |
pass
|
| 79 |
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def _ensure_dir_writable(path: str) -> str:
|
| 82 |
try:
|
| 83 |
os.makedirs(path, exist_ok=True)
|
|
@@ -245,6 +307,7 @@ def _prepare_script(
|
|
| 245 |
override_learning_rate: Optional[str] = None,
|
| 246 |
override_network_dim: Optional[int] = None,
|
| 247 |
override_seed: Optional[int] = None,
|
|
|
|
| 248 |
) -> Path:
|
| 249 |
"""Create a temporary copy of train_QIE.sh with injected variables.
|
| 250 |
|
|
@@ -366,6 +429,15 @@ def _prepare_script(
|
|
| 366 |
if override_seed is not None:
|
| 367 |
txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
|
| 368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
# Prefer overriding variable definitions at top of script (safer than CLI regex)
|
| 370 |
def _set_var(name: str, value: str) -> None:
|
| 371 |
nonlocal txt
|
|
@@ -533,6 +605,12 @@ def run_training(
|
|
| 533 |
ctrl7_suffix: str,
|
| 534 |
learning_rate: str,
|
| 535 |
network_dim: int,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
seed: int,
|
| 537 |
max_epochs: int,
|
| 538 |
save_every: int,
|
|
@@ -614,6 +692,20 @@ def run_training(
|
|
| 614 |
# Decide dataset_config path with fallback to runtime auto dir
|
| 615 |
ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
|
| 616 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 617 |
# Resolve models_root and set output_dir_base to the unique dataset dir
|
| 618 |
models_root = MODELS_ROOT_RUNTIME
|
| 619 |
out_base = ds_dir
|
|
@@ -640,6 +732,7 @@ def run_training(
|
|
| 640 |
control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
|
| 641 |
override_learning_rate=(learning_rate or None),
|
| 642 |
override_network_dim=int(network_dim) if network_dim is not None else None,
|
|
|
|
| 643 |
override_seed=int(seed) if seed is not None else None,
|
| 644 |
)
|
| 645 |
|
|
@@ -730,10 +823,18 @@ def build_ui() -> gr.Blocks:
|
|
| 730 |
with gr.Row():
|
| 731 |
lr_input = gr.Textbox(label="Learning rate", value="1e-3")
|
| 732 |
dim_input = gr.Number(label="Network dim", value=4, precision=0)
|
|
|
|
| 733 |
seed_input = gr.Number(label="Seed", value=42, precision=0)
|
| 734 |
max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
|
| 735 |
save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
|
| 736 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 737 |
with gr.Accordion("Target Image", elem_classes=["pad-section_0"]):
|
| 738 |
with gr.Group():
|
| 739 |
with gr.Row():
|
|
@@ -859,7 +960,9 @@ def build_ui() -> gr.Blocks:
|
|
| 859 |
ctrl5_files, ctrl5_prefix, ctrl5_suffix,
|
| 860 |
ctrl6_files, ctrl6_prefix, ctrl6_suffix,
|
| 861 |
ctrl7_files, ctrl7_prefix, ctrl7_suffix,
|
| 862 |
-
lr_input, dim_input,
|
|
|
|
|
|
|
| 863 |
],
|
| 864 |
outputs=[logs, ckpt_files],
|
| 865 |
)
|
|
|
|
| 78 |
pass
|
| 79 |
|
| 80 |
|
| 81 |
+
def _update_dataset_toml(
|
| 82 |
+
path: str,
|
| 83 |
+
*,
|
| 84 |
+
img_res_w: Optional[int] = None,
|
| 85 |
+
img_res_h: Optional[int] = None,
|
| 86 |
+
train_batch_size: Optional[int] = None,
|
| 87 |
+
control_res_w: Optional[int] = None,
|
| 88 |
+
control_res_h: Optional[int] = None,
|
| 89 |
+
) -> None:
|
| 90 |
+
"""Update dataset TOML for resolution/batch/control resolution in-place.
|
| 91 |
+
|
| 92 |
+
- Updates [general] resolution and batch_size if provided.
|
| 93 |
+
- Updates first [[datasets]] qwen_image_edit_control_resolution if provided.
|
| 94 |
+
- Creates sections/keys if missing.
|
| 95 |
+
"""
|
| 96 |
+
try:
|
| 97 |
+
txt = Path(path).read_text(encoding="utf-8")
|
| 98 |
+
except Exception:
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
def _set_in_general(block: str, key: str, value_line: str) -> str:
|
| 102 |
+
import re as _re
|
| 103 |
+
if _re.search(rf"(?m)^\s*{_re.escape(key)}\s*=", block):
|
| 104 |
+
block = _re.sub(rf"(?m)^\s*{_re.escape(key)}\s*=.*$", value_line, block)
|
| 105 |
+
else:
|
| 106 |
+
block = block.rstrip() + "\n" + value_line + "\n"
|
| 107 |
+
return block
|
| 108 |
+
|
| 109 |
+
import re
|
| 110 |
+
m = re.search(r"(?ms)^\[general\]\s*(.*?)(?=^\[|\Z)", txt)
|
| 111 |
+
if not m:
|
| 112 |
+
gen = "[general]\n"
|
| 113 |
+
if img_res_w and img_res_h:
|
| 114 |
+
gen += f"resolution = [{int(img_res_w)}, {int(img_res_h)}]\n"
|
| 115 |
+
if train_batch_size is not None:
|
| 116 |
+
gen += f"batch_size = {int(train_batch_size)}\n"
|
| 117 |
+
txt = gen + "\n" + txt
|
| 118 |
+
else:
|
| 119 |
+
head, block, tail = txt[:m.start(1)], m.group(1), txt[m.end(1):]
|
| 120 |
+
if img_res_w and img_res_h:
|
| 121 |
+
block = _set_in_general(block, "resolution", f"resolution = [{int(img_res_w)}, {int(img_res_h)}]")
|
| 122 |
+
if train_batch_size is not None:
|
| 123 |
+
block = _set_in_general(block, "batch_size", f"batch_size = {int(train_batch_size)}")
|
| 124 |
+
txt = head + block + tail
|
| 125 |
+
|
| 126 |
+
if control_res_w and control_res_h:
|
| 127 |
+
m2 = re.search(r"(?ms)^\[\[datasets\]\]\s*(.*?)(?=^\[\[|\Z)", txt)
|
| 128 |
+
if m2:
|
| 129 |
+
head, block, tail = txt[:m2.start(1)], m2.group(1), txt[m2.end(1):]
|
| 130 |
+
line = f"qwen_image_edit_control_resolution = [{int(control_res_w)}, {int(control_res_h)}]"
|
| 131 |
+
if re.search(r"(?m)^\s*qwen_image_edit_control_resolution\s*=", block):
|
| 132 |
+
block = re.sub(r"(?m)^\s*qwen_image_edit_control_resolution\s*=.*$", line, block)
|
| 133 |
+
else:
|
| 134 |
+
block = block.rstrip() + "\n" + line + "\n"
|
| 135 |
+
txt = head + block + tail
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
Path(path).write_text(txt, encoding="utf-8")
|
| 139 |
+
except Exception:
|
| 140 |
+
pass
|
| 141 |
+
|
| 142 |
+
|
| 143 |
def _ensure_dir_writable(path: str) -> str:
|
| 144 |
try:
|
| 145 |
os.makedirs(path, exist_ok=True)
|
|
|
|
| 307 |
override_learning_rate: Optional[str] = None,
|
| 308 |
override_network_dim: Optional[int] = None,
|
| 309 |
override_seed: Optional[int] = None,
|
| 310 |
+
override_te_cache_bs: Optional[int] = None,
|
| 311 |
) -> Path:
|
| 312 |
"""Create a temporary copy of train_QIE.sh with injected variables.
|
| 313 |
|
|
|
|
| 429 |
if override_seed is not None:
|
| 430 |
txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
|
| 431 |
|
| 432 |
+
# Optionally override text-encoder cache batch size
|
| 433 |
+
if override_te_cache_bs is not None and override_te_cache_bs > 0:
|
| 434 |
+
txt = re.sub(
|
| 435 |
+
r"(qwen_image_cache_text_encoder_outputs\.py[^\n]*--batch_size\s+)\d+",
|
| 436 |
+
rf"\g<1>{int(override_te_cache_bs)}",
|
| 437 |
+
txt,
|
| 438 |
+
flags=re.MULTILINE,
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
# Prefer overriding variable definitions at top of script (safer than CLI regex)
|
| 442 |
def _set_var(name: str, value: str) -> None:
|
| 443 |
nonlocal txt
|
|
|
|
| 605 |
ctrl7_suffix: str,
|
| 606 |
learning_rate: str,
|
| 607 |
network_dim: int,
|
| 608 |
+
train_res_w: int,
|
| 609 |
+
train_res_h: int,
|
| 610 |
+
train_batch_size: int,
|
| 611 |
+
control_res_w: int,
|
| 612 |
+
control_res_h: int,
|
| 613 |
+
te_cache_batch_size: int,
|
| 614 |
seed: int,
|
| 615 |
max_epochs: int,
|
| 616 |
save_every: int,
|
|
|
|
| 692 |
# Decide dataset_config path with fallback to runtime auto dir
|
| 693 |
ds_conf = str(Path(AUTO_DIR_RUNTIME) / "dataset_QIE.toml")
|
| 694 |
|
| 695 |
+
# Update dataset config with requested resolution/batch settings
|
| 696 |
+
try:
|
| 697 |
+
_update_dataset_toml(
|
| 698 |
+
ds_conf,
|
| 699 |
+
img_res_w=int(train_res_w) if train_res_w else None,
|
| 700 |
+
img_res_h=int(train_res_h) if train_res_h else None,
|
| 701 |
+
train_batch_size=int(train_batch_size) if train_batch_size else None,
|
| 702 |
+
control_res_w=int(control_res_w) if control_res_w else None,
|
| 703 |
+
control_res_h=int(control_res_h) if control_res_h else None,
|
| 704 |
+
)
|
| 705 |
+
log_buf += f"[QIE] Updated dataset config: resolution=({train_res_w},{train_res_h}), batch_size={train_batch_size}, control_res=({control_res_w},{control_res_h})\n"
|
| 706 |
+
except Exception as e:
|
| 707 |
+
log_buf += f"[QIE] WARN: failed to update dataset config: {e}\n"
|
| 708 |
+
|
| 709 |
# Resolve models_root and set output_dir_base to the unique dataset dir
|
| 710 |
models_root = MODELS_ROOT_RUNTIME
|
| 711 |
out_base = ds_dir
|
|
|
|
| 732 |
control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
|
| 733 |
override_learning_rate=(learning_rate or None),
|
| 734 |
override_network_dim=int(network_dim) if network_dim is not None else None,
|
| 735 |
+
override_te_cache_bs=int(te_cache_batch_size) if te_cache_batch_size else None,
|
| 736 |
override_seed=int(seed) if seed is not None else None,
|
| 737 |
)
|
| 738 |
|
|
|
|
| 823 |
with gr.Row():
|
| 824 |
lr_input = gr.Textbox(label="Learning rate", value="1e-3")
|
| 825 |
dim_input = gr.Number(label="Network dim", value=4, precision=0)
|
| 826 |
+
train_bs = gr.Number(label="Batch size (dataset)", value=1, precision=0)
|
| 827 |
seed_input = gr.Number(label="Seed", value=42, precision=0)
|
| 828 |
max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
|
| 829 |
save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
|
| 830 |
|
| 831 |
+
with gr.Row():
|
| 832 |
+
tr_w = gr.Number(label="Image resolution W", value=1024, precision=0)
|
| 833 |
+
tr_h = gr.Number(label="Image resolution H", value=1024, precision=0)
|
| 834 |
+
cr_w = gr.Number(label="Control resolution W", value=1024, precision=0)
|
| 835 |
+
cr_h = gr.Number(label="Control resolution H", value=1024, precision=0)
|
| 836 |
+
te_bs = gr.Number(label="TE cache batch size", value=16, precision=0)
|
| 837 |
+
|
| 838 |
with gr.Accordion("Target Image", elem_classes=["pad-section_0"]):
|
| 839 |
with gr.Group():
|
| 840 |
with gr.Row():
|
|
|
|
| 960 |
ctrl5_files, ctrl5_prefix, ctrl5_suffix,
|
| 961 |
ctrl6_files, ctrl6_prefix, ctrl6_suffix,
|
| 962 |
ctrl7_files, ctrl7_prefix, ctrl7_suffix,
|
| 963 |
+
lr_input, dim_input,
|
| 964 |
+
tr_w, tr_h, train_bs, cr_w, cr_h, te_bs,
|
| 965 |
+
seed_input, max_epochs, save_every,
|
| 966 |
],
|
| 967 |
outputs=[logs, ckpt_files],
|
| 968 |
)
|