Spaces: Running on Zero
	Enhance training functionality with ZeroGPU support and UI adjustments. Add options to override max epochs and save frequency, and implement GPU request handling for Spaces compatibility.
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -9,9 +9,10 @@ from pathlib import Path 
     | 
|
| 9 | 
         
             
            from typing import Dict, Iterable, List, Optional
         
     | 
| 10 | 
         | 
| 11 | 
         
             
            import gradio as gr
         
     | 
| 12 | 
         
            -
             
     | 
| 13 | 
         
             
            import spaces
         
     | 
| 14 | 
         | 
| 
         | 
|
| 
         | 
|
| 15 | 
         
             
            # Local modules
         
     | 
| 16 | 
         
             
            from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR
         
     | 
| 17 | 
         | 
| 
         @@ -73,6 +74,8 @@ def _prepare_script( 
     | 
|
| 73 | 
         
             
                models_root: str,
         
     | 
| 74 | 
         
             
                output_dir_base: Optional[str] = None,
         
     | 
| 75 | 
         
             
                dataset_config: Optional[str] = None,
         
     | 
| 
         | 
|
| 
         | 
|
| 76 | 
         
             
            ) -> Path:
         
     | 
| 77 | 
         
             
                """Create a temporary copy of train_QIE.sh with injected variables.
         
     | 
| 78 | 
         | 
| 
         @@ -137,6 +140,24 @@ def _prepare_script( 
     | 
|
| 137 | 
         
             
                txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors")
         
     | 
| 138 | 
         
             
                txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors")
         
     | 
| 139 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 140 | 
         
             
                # Write to a temp file alongside this repo for easier inspection
         
     | 
| 141 | 
         
             
                run_dir = TRAINING_DIR / ".gradio_runs"
         
     | 
| 142 | 
         
             
                run_dir.mkdir(parents=True, exist_ok=True)
         
     | 
| 
         @@ -199,6 +220,7 @@ def _startup_clone_musubi_tuner() -> None: 
     | 
|
| 199 | 
         
             
                    print(f"[QIE] Clone failed: {e}")
         
     | 
| 200 | 
         | 
| 201 | 
         | 
| 
         | 
|
| 202 | 
         
             
            def run_training(
         
     | 
| 203 | 
         
             
                dataset_name: str,
         
     | 
| 204 | 
         
             
                caption: str,
         
     | 
| 
         @@ -215,6 +237,8 @@ def run_training( 
     | 
|
| 215 | 
         
             
                models_root: str,
         
     | 
| 216 | 
         
             
                output_dir_base: str,
         
     | 
| 217 | 
         
             
                dataset_config: str,
         
     | 
| 
         | 
|
| 
         | 
|
| 218 | 
         
             
            ) -> Iterable[str]:
         
     | 
| 219 | 
         
             
                # Basic validation
         
     | 
| 220 | 
         
             
                if not dataset_name.strip():
         
     | 
| 
         @@ -241,8 +265,11 @@ def run_training( 
     | 
|
| 241 | 
         
             
                    models_root=models_root.strip() or DEFAULT_MODELS_ROOT,
         
     | 
| 242 | 
         
             
                    output_dir_base=(output_dir_base.strip() or None),
         
     | 
| 243 | 
         
             
                    dataset_config=(dataset_config.strip() or None),
         
     | 
| 
         | 
|
| 
         | 
|
| 244 | 
         
             
                )
         
     | 
| 245 | 
         | 
| 
         | 
|
| 246 | 
         
             
                shell = _pick_shell()
         
     | 
| 247 | 
         
             
                yield f"[QIE] Using shell: {shell}"
         
     | 
| 248 | 
         
             
                yield f"[QIE] Running script: {tmp_script}"
         
     | 
| 
         @@ -300,12 +327,17 @@ def build_ui() -> gr.Blocks: 
     | 
|
| 300 | 
         
             
                    run_btn = gr.Button("Start Training", variant="primary")
         
     | 
| 301 | 
         
             
                    logs = gr.Textbox(label="Logs", lines=20)
         
     | 
| 302 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 303 | 
         
             
                    run_btn.click(
         
     | 
| 304 | 
         
             
                        fn=run_training,
         
     | 
| 305 | 
         
             
                        inputs=[
         
     | 
| 306 | 
         
             
                            dataset_name, caption, data_root, image_folder,
         
     | 
| 307 | 
         
             
                            c0, c1, c2, c3, c4, c5, c6, c7,
         
     | 
| 308 | 
         
             
                            models_root, output_dir_base, dataset_config,
         
     | 
| 
         | 
|
| 309 | 
         
             
                        ],
         
     | 
| 310 | 
         
             
                        outputs=logs,
         
     | 
| 311 | 
         
             
                    )
         
     | 
| 
         @@ -313,6 +345,11 @@ def build_ui() -> gr.Blocks: 
     | 
|
| 313 | 
         
             
                return demo
         
     | 
| 314 | 
         | 
| 315 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 316 | 
         
             
            def _startup_download_models() -> None:
         
     | 
| 317 | 
         
             
                models_dir = DEFAULT_MODELS_ROOT
         
     | 
| 318 | 
         
             
                print(f"[QIE] Ensuring models in: {models_dir}")
         
     | 
| 
         @@ -323,6 +360,13 @@ def _startup_download_models() -> None: 
     | 
|
| 323 | 
         | 
| 324 | 
         | 
| 325 | 
         
             
            if __name__ == "__main__":
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 326 | 
         
             
                # 1) Ensure musubi-tuner is cloned before anything else
         
     | 
| 327 | 
         
             
                _startup_clone_musubi_tuner()
         
     | 
| 328 | 
         | 
| 
         @@ -331,4 +375,6 @@ if __name__ == "__main__": 
     | 
|
| 331 | 
         | 
| 332 | 
         
             
                # 3) Launch Gradio app
         
     | 
| 333 | 
         
             
                ui = build_ui()
         
     | 
| 
         | 
|
| 
         | 
|
| 334 | 
         
             
                ui.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
         
     | 
| 
         | 
|
| 9 | 
         
             
            from typing import Dict, Iterable, List, Optional
         
     | 
| 10 | 
         | 
| 11 | 
         
             
            import gradio as gr
         
     | 
| 
         | 
|
| 12 | 
         
             
            import spaces
         
     | 
| 13 | 
         | 
| 14 | 
         
            +
            # No Spaces GPU reservation to allow zero-GPU (CPU-only) usage
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
             
            # Local modules
         
     | 
| 17 | 
         
             
            from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR
         
     | 
| 18 | 
         | 
| 
         | 
|
| 74 | 
         
             
                models_root: str,
         
     | 
| 75 | 
         
             
                output_dir_base: Optional[str] = None,
         
     | 
| 76 | 
         
             
                dataset_config: Optional[str] = None,
         
     | 
| 77 | 
         
            +
                override_max_epochs: Optional[int] = None,
         
     | 
| 78 | 
         
            +
                override_save_every: Optional[int] = None,
         
     | 
| 79 | 
         
             
            ) -> Path:
         
     | 
| 80 | 
         
             
                """Create a temporary copy of train_QIE.sh with injected variables.
         
     | 
| 81 | 
         | 
| 
         | 
|
| 140 | 
         
             
                txt = _replace_model_path(txt, "text_encoder", "text_encoder/qwen_2.5_vl_7b.safetensors")
         
     | 
| 141 | 
         
             
                txt = _replace_model_path(txt, "dit", "dit/qwen_image_edit_2509_bf16.safetensors")
         
     | 
| 142 | 
         | 
| 143 | 
         
            +
                # ZeroGPU compatibility: avoid spawning via 'accelerate launch'.
         
     | 
| 144 | 
         
            +
                # Run the training module directly in-process so GPU stays attached
         
     | 
| 145 | 
         
            +
                # to the same Python request context.
         
     | 
| 146 | 
         
            +
                txt = re.sub(
         
     | 
| 147 | 
         
            +
                    r"\baccelerate\s+launch\s+src/musubi_tuner/qwen_image_train_network.py",
         
     | 
| 148 | 
         
            +
                    r"python src/musubi_tuner/qwen_image_train_network.py",
         
     | 
| 149 | 
         
            +
                    txt,
         
     | 
| 150 | 
         
            +
                    flags=re.MULTILINE,
         
     | 
| 151 | 
         
            +
                )
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
                # Optionally override epochs and save frequency for ZeroGPU time slicing
         
     | 
| 154 | 
         
            +
                if override_max_epochs is not None and override_max_epochs > 0:
         
     | 
| 155 | 
         
            +
                    txt = re.sub(r"--max_train_epochs\s+\d+",
         
     | 
| 156 | 
         
            +
                                 f"--max_train_epochs {override_max_epochs}", txt)
         
     | 
| 157 | 
         
            +
                if override_save_every is not None and override_save_every > 0:
         
     | 
| 158 | 
         
            +
                    txt = re.sub(r"--save_every_n_epochs\s+\d+",
         
     | 
| 159 | 
         
            +
                                 f"--save_every_n_epochs {override_save_every}", txt)
         
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
             
                # Write to a temp file alongside this repo for easier inspection
         
     | 
| 162 | 
         
             
                run_dir = TRAINING_DIR / ".gradio_runs"
         
     | 
| 163 | 
         
             
                run_dir.mkdir(parents=True, exist_ok=True)
         
     | 
| 
         | 
|
| 220 | 
         
             
                    print(f"[QIE] Clone failed: {e}")
         
     | 
| 221 | 
         | 
| 222 | 
         | 
| 223 | 
         
            +
            @spaces.GPU(duration=7200)
         
     | 
| 224 | 
         
             
            def run_training(
         
     | 
| 225 | 
         
             
                dataset_name: str,
         
     | 
| 226 | 
         
             
                caption: str,
         
     | 
| 
         | 
|
| 237 | 
         
             
                models_root: str,
         
     | 
| 238 | 
         
             
                output_dir_base: str,
         
     | 
| 239 | 
         
             
                dataset_config: str,
         
     | 
| 240 | 
         
            +
                max_epochs: int,
         
     | 
| 241 | 
         
            +
                save_every: int,
         
     | 
| 242 | 
         
             
            ) -> Iterable[str]:
         
     | 
| 243 | 
         
             
                # Basic validation
         
     | 
| 244 | 
         
             
                if not dataset_name.strip():
         
     | 
| 
         | 
|
| 265 | 
         
             
                    models_root=models_root.strip() or DEFAULT_MODELS_ROOT,
         
     | 
| 266 | 
         
             
                    output_dir_base=(output_dir_base.strip() or None),
         
     | 
| 267 | 
         
             
                    dataset_config=(dataset_config.strip() or None),
         
     | 
| 268 | 
         
            +
                    override_max_epochs=max_epochs if max_epochs and max_epochs > 0 else None,
         
     | 
| 269 | 
         
            +
                    override_save_every=save_every if save_every and save_every > 0 else None,
         
     | 
| 270 | 
         
             
                )
         
     | 
| 271 | 
         | 
| 272 | 
         
            +
             
     | 
| 273 | 
         
             
                shell = _pick_shell()
         
     | 
| 274 | 
         
             
                yield f"[QIE] Using shell: {shell}"
         
     | 
| 275 | 
         
             
                yield f"[QIE] Running script: {tmp_script}"
         
     | 
| 
         | 
|
| 327 | 
         
             
                    run_btn = gr.Button("Start Training", variant="primary")
         
     | 
| 328 | 
         
             
                    logs = gr.Textbox(label="Logs", lines=20)
         
     | 
| 329 | 
         | 
| 330 | 
         
            +
                    with gr.Row():
         
     | 
| 331 | 
         
            +
                        max_epochs = gr.Number(label="Max epochs (this run)", value=10, precision=0)
         
     | 
| 332 | 
         
            +
                        save_every = gr.Number(label="Save every N epochs", value=5, precision=0)
         
     | 
| 333 | 
         
            +
             
     | 
| 334 | 
         
             
                    run_btn.click(
         
     | 
| 335 | 
         
             
                        fn=run_training,
         
     | 
| 336 | 
         
             
                        inputs=[
         
     | 
| 337 | 
         
             
                            dataset_name, caption, data_root, image_folder,
         
     | 
| 338 | 
         
             
                            c0, c1, c2, c3, c4, c5, c6, c7,
         
     | 
| 339 | 
         
             
                            models_root, output_dir_base, dataset_config,
         
     | 
| 340 | 
         
            +
                            max_epochs, save_every,
         
     | 
| 341 | 
         
             
                        ],
         
     | 
| 342 | 
         
             
                        outputs=logs,
         
     | 
| 343 | 
         
             
                    )
         
     | 
| 
         | 
|
| 345 | 
         
             
                return demo
         
     | 
| 346 | 
         | 
| 347 | 
         | 
| 348 | 
         
            +
            @spaces.GPU(duration=600)
         
     | 
| 349 | 
         
            +
            def _request_gpu_on_startup() -> str:
         
     | 
| 350 | 
         
            +
                return "gpu-requested"
         
     | 
| 351 | 
         
            +
             
     | 
| 352 | 
         
            +
             
     | 
| 353 | 
         
             
            def _startup_download_models() -> None:
         
     | 
| 354 | 
         
             
                models_dir = DEFAULT_MODELS_ROOT
         
     | 
| 355 | 
         
             
                print(f"[QIE] Ensuring models in: {models_dir}")
         
     | 
| 
         | 
|
| 360 | 
         | 
| 361 | 
         | 
| 362 | 
         
             
            if __name__ == "__main__":
         
     | 
| 363 | 
         
            +
                # 0) Request GPU immediately for Spaces dynamic hardware
         
     | 
| 364 | 
         
            +
                try:
         
     | 
| 365 | 
         
            +
                    tag = _request_gpu_on_startup()
         
     | 
| 366 | 
         
            +
                    print(f"[QIE] Spaces GPU tag: {tag}")
         
     | 
| 367 | 
         
            +
                except Exception as e:
         
     | 
| 368 | 
         
            +
                    print(f"[QIE] GPU request skipped or failed: {e}")
         
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
             
                # 1) Ensure musubi-tuner is cloned before anything else
         
     | 
| 371 | 
         
             
                _startup_clone_musubi_tuner()
         
     | 
| 372 | 
         | 
| 
         | 
|
| 375 | 
         | 
| 376 | 
         
             
                # 3) Launch Gradio app
         
     | 
| 377 | 
         
             
                ui = build_ui()
         
     | 
| 378 | 
         
            +
                # Limit concurrency (training is heavy). Enable queue for Spaces compatibility.
         
     | 
| 379 | 
         
            +
                ui = ui.queue(concurrency_count=1, max_size=16)
         
     | 
| 380 | 
         
             
                ui.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
         
     |