yeq6x commited on
Commit
94acb06
·
1 Parent(s): 6b04281

Enhance training functionality with ZeroGPU support and UI adjustments. Added options to override max epochs and save frequency, and implemented GPU request handling for Spaces compatibility.

Browse files
Files changed (1) hide show
  1. app.py +47 -1
app.py CHANGED
@@ -9,9 +9,10 @@ from pathlib import Path
9
  from typing import Dict, Iterable, List, Optional
10
 
11
  import gradio as gr
12
-
13
  import spaces
14
 
 
 
15
  # Local modules
16
  from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR
17
 
@@ -73,6 +74,8 @@ def _prepare_script(
73
  models_root: str,
74
  output_dir_base: Optional[str] = None,
75
  dataset_config: Optional[str] = None,
 
 
76
  ) -> Path:
77
  """Create a temporary copy of train_QIE.sh with injected variables.
78
 
@@ -137,6 +140,24 @@ def _prepare_script(
137
  txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors")
138
  txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors")
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  # Write to a temp file alongside this repo for easier inspection
141
  run_dir = TRAINING_DIR / ".gradio_runs"
142
  run_dir.mkdir(parents=True, exist_ok=True)
@@ -199,6 +220,7 @@ def _startup_clone_musubi_tuner() -> None:
199
  print(f"[QIE] Clone failed: {e}")
200
 
201
 
 
202
  def run_training(
203
  dataset_name: str,
204
  caption: str,
@@ -215,6 +237,8 @@ def run_training(
215
  models_root: str,
216
  output_dir_base: str,
217
  dataset_config: str,
 
 
218
  ) -> Iterable[str]:
219
  # Basic validation
220
  if not dataset_name.strip():
@@ -241,8 +265,11 @@ def run_training(
241
  models_root=models_root.strip() or DEFAULT_MODELS_ROOT,
242
  output_dir_base=(output_dir_base.strip() or None),
243
  dataset_config=(dataset_config.strip() or None),
 
 
244
  )
245
 
 
246
  shell = _pick_shell()
247
  yield f"[QIE] Using shell: {shell}"
248
  yield f"[QIE] Running script: {tmp_script}"
@@ -300,12 +327,17 @@ def build_ui() -> gr.Blocks:
300
  run_btn = gr.Button("Start Training", variant="primary")
301
  logs = gr.Textbox(label="Logs", lines=20)
302
 
 
 
 
 
303
  run_btn.click(
304
  fn=run_training,
305
  inputs=[
306
  dataset_name, caption, data_root, image_folder,
307
  c0, c1, c2, c3, c4, c5, c6, c7,
308
  models_root, output_dir_base, dataset_config,
 
309
  ],
310
  outputs=logs,
311
  )
@@ -313,6 +345,11 @@ def build_ui() -> gr.Blocks:
313
  return demo
314
 
315
 
 
 
 
 
 
316
  def _startup_download_models() -> None:
317
  models_dir = DEFAULT_MODELS_ROOT
318
  print(f"[QIE] Ensuring models in: {models_dir}")
@@ -323,6 +360,13 @@ def _startup_download_models() -> None:
323
 
324
 
325
  if __name__ == "__main__":
 
 
 
 
 
 
 
326
  # 1) Ensure musubi-tuner is cloned before anything else
327
  _startup_clone_musubi_tuner()
328
 
@@ -331,4 +375,6 @@ if __name__ == "__main__":
331
 
332
  # 3) Launch Gradio app
333
  ui = build_ui()
 
 
334
  ui.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
9
  from typing import Dict, Iterable, List, Optional
10
 
11
  import gradio as gr
 
12
  import spaces
13
 
14
+ # No Spaces GPU reservation to allow zero-GPU (CPU-only) usage
15
+
16
  # Local modules
17
  from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR
18
 
 
74
  models_root: str,
75
  output_dir_base: Optional[str] = None,
76
  dataset_config: Optional[str] = None,
77
+ override_max_epochs: Optional[int] = None,
78
+ override_save_every: Optional[int] = None,
79
  ) -> Path:
80
  """Create a temporary copy of train_QIE.sh with injected variables.
81
 
 
140
  txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors")
141
  txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors")
142
 
143
+ # ZeroGPU compatibility: avoid spawning via 'accelerate launch'.
144
+ # Run the training module directly in-process so GPU stays attached
145
+ # to the same Python request context.
146
+ txt = re.sub(
147
+ r"\baccelerate\s+launch\s+src/musubi_tuner/qwen_image_train_network.py",
148
+ r"python src/musubi_tuner/qwen_image_train_network.py",
149
+ txt,
150
+ flags=re.MULTILINE,
151
+ )
152
+
153
+ # Optionally override epochs and save frequency for ZeroGPU time slicing
154
+ if override_max_epochs is not None and override_max_epochs > 0:
155
+ txt = re.sub(r"--max_train_epochs\s+\d+",
156
+ f"--max_train_epochs {override_max_epochs}", txt)
157
+ if override_save_every is not None and override_save_every > 0:
158
+ txt = re.sub(r"--save_every_n_epochs\s+\d+",
159
+ f"--save_every_n_epochs {override_save_every}", txt)
160
+
161
  # Write to a temp file alongside this repo for easier inspection
162
  run_dir = TRAINING_DIR / ".gradio_runs"
163
  run_dir.mkdir(parents=True, exist_ok=True)
 
220
  print(f"[QIE] Clone failed: {e}")
221
 
222
 
223
+ @spaces.GPU(duration=7200)
224
  def run_training(
225
  dataset_name: str,
226
  caption: str,
 
237
  models_root: str,
238
  output_dir_base: str,
239
  dataset_config: str,
240
+ max_epochs: int,
241
+ save_every: int,
242
  ) -> Iterable[str]:
243
  # Basic validation
244
  if not dataset_name.strip():
 
265
  models_root=models_root.strip() or DEFAULT_MODELS_ROOT,
266
  output_dir_base=(output_dir_base.strip() or None),
267
  dataset_config=(dataset_config.strip() or None),
268
+ override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
269
+ override_save_every=save_every if save_every and save_every > 0 else None,
270
  )
271
 
272
+
273
  shell = _pick_shell()
274
  yield f"[QIE] Using shell: {shell}"
275
  yield f"[QIE] Running script: {tmp_script}"
 
327
  run_btn = gr.Button("Start Training", variant="primary")
328
  logs = gr.Textbox(label="Logs", lines=20)
329
 
330
+ with gr.Row():
331
+ max_epochs = gr.Number(label="Max epochs (this run)", value=10, precision=0)
332
+ save_every = gr.Number(label="Save every N epochs", value=5, precision=0)
333
+
334
  run_btn.click(
335
  fn=run_training,
336
  inputs=[
337
  dataset_name, caption, data_root, image_folder,
338
  c0, c1, c2, c3, c4, c5, c6, c7,
339
  models_root, output_dir_base, dataset_config,
340
+ max_epochs, save_every,
341
  ],
342
  outputs=logs,
343
  )
 
345
  return demo
346
 
347
 
348
+ @spaces.GPU(duration=600)
349
+ def _request_gpu_on_startup() -> str:
350
+ return "gpu-requested"
351
+
352
+
353
  def _startup_download_models() -> None:
354
  models_dir = DEFAULT_MODELS_ROOT
355
  print(f"[QIE] Ensuring models in: {models_dir}")
 
360
 
361
 
362
  if __name__ == "__main__":
363
+ # 0) Request GPU immediately for Spaces dynamic hardware
364
+ try:
365
+ tag = _request_gpu_on_startup()
366
+ print(f"[QIE] Spaces GPU tag: {tag}")
367
+ except Exception as e:
368
+ print(f"[QIE] GPU request skipped or failed: {e}")
369
+
370
  # 1) Ensure musubi-tuner is cloned before anything else
371
  _startup_clone_musubi_tuner()
372
 
 
375
 
376
  # 3) Launch Gradio app
377
  ui = build_ui()
378
+ # Limit concurrency (training is heavy). Enable queue for Spaces compatibility.
379
+ ui = ui.queue(concurrency_count=1, max_size=16)
380
  ui.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))