import torch
import spaces
import gradio as gr
import sys
import platform
import diffusers
import transformers
import psutil
import os
import time
import traceback

import torchvision.transforms as T  # needed for T.ToPILImage() in latent_to_image(); missing in the original imports

from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

latent_history = []
# ============================================================
# LOGGING BUFFER
# ============================================================
LOGS = ""

def log(msg):
    global LOGS
    print(msg)
    LOGS += msg + "\n"
    return msg
# ============================================================
# SYSTEM METRICS — LIVE GPU + CPU MONITORING
# ============================================================
def log_system_stats(tag=""):
    try:
        log(f"\n===== 🔥 SYSTEM STATS {tag} =====")

        # ============= GPU STATS =============
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1e9
            reserved = torch.cuda.memory_reserved(0) / 1e9
            total = torch.cuda.get_device_properties(0).total_memory / 1e9
            free = total - allocated
            log(f"💠 GPU Total     : {total:.2f} GB")
            log(f"💠 GPU Allocated : {allocated:.2f} GB")
            log(f"💠 GPU Reserved  : {reserved:.2f} GB")
            log(f"💠 GPU Free      : {free:.2f} GB")

        # ============= CPU STATS =============
        cpu = psutil.cpu_percent()
        ram_used = psutil.virtual_memory().used / 1e9
        ram_total = psutil.virtual_memory().total / 1e9
        log(f"🧠 CPU Usage : {cpu}%")
        log(f"🧠 RAM Used  : {ram_used:.2f} GB / {ram_total:.2f} GB")
    except Exception as e:
        log(f"⚠️ Failed to log system stats: {e}")
# ============================================================
# ENVIRONMENT INFO
# ============================================================
log("===================================================")
log("🔍 Z-IMAGE-TURBO DEBUGGING + LIVE METRIC LOGGER")
log("===================================================\n")

log(f"📌 PYTHON VERSION       : {sys.version.replace(chr(10), ' ')}")
log(f"📌 PLATFORM             : {platform.platform()}")
log(f"📌 TORCH VERSION        : {torch.__version__}")
log(f"📌 TRANSFORMERS VERSION : {transformers.__version__}")
log(f"📌 DIFFUSERS VERSION    : {diffusers.__version__}")
log(f"📌 CUDA AVAILABLE       : {torch.cuda.is_available()}")

log_system_stats("AT STARTUP")

if not torch.cuda.is_available():
    raise RuntimeError("❌ CUDA Required")

device = "cuda"
gpu_id = 0
# ============================================================
# MODEL SETTINGS
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
USE_CPU_OFFLOAD = False

log("\n===================================================")
log("🧠 MODEL CONFIGURATION")
log("===================================================")
log(f"Model ID              : {model_id}")
log(f"Model Cache Directory : {model_cache}")
log(f"torch_dtype           : {torch_dtype}")
log(f"USE_CPU_OFFLOAD       : {USE_CPU_OFFLOAD}")

log_system_stats("BEFORE TRANSFORMER LOAD")
# ============================================================
# FUNCTION TO CONVERT LATENTS TO IMAGE
# ============================================================
def latent_to_image(latent):
    """
    Convert a latent tensor to a PIL image using pipe.vae.
    """
    try:
        decoded = pipe.vae.decode(latent)
        # diffusers VAEs usually return a DecoderOutput; fall back to the raw tensor otherwise
        img_tensor = decoded.sample if hasattr(decoded, "sample") else decoded
        img_tensor = (img_tensor / 2 + 0.5).clamp(0, 1)
        pil_img = T.ToPILImage()(img_tensor[0].float().cpu())  # single image
        return pil_img
    except Exception as e:
        log(f"⚠️ Failed to decode latent: {e}")
        # fallback: blank image
        return Image.new("RGB", (latent.shape[-1] * 8, latent.shape[-2] * 8), color=(255, 255, 255))
# ============================================================
# SAFE TRANSFORMER INSPECTION
# ============================================================
def inspect_transformer(model, name):
    log(f"\n🔍🔍 FULL TRANSFORMER DEBUG DUMP: {name}")
    log("=" * 80)
    try:
        log(f"Model class    : {model.__class__.__name__}")
        log(f"DType          : {getattr(model, 'dtype', 'unknown')}")
        log(f"Device         : {next(model.parameters()).device}")
        log(f"Requires Grad? : {any(p.requires_grad for p in model.parameters())}")

        # Check quantization
        if hasattr(model, "is_loaded_in_4bit"):
            log(f"4bit Quantization : {model.is_loaded_in_4bit}")
        if hasattr(model, "is_loaded_in_8bit"):
            log(f"8bit Quantization : {model.is_loaded_in_8bit}")

        # Find blocks
        candidates = ["transformer_blocks", "blocks", "layers", "encoder", "model"]
        blocks = None
        chosen_attr = None
        for attr in candidates:
            if hasattr(model, attr):
                blocks = getattr(model, attr)
                chosen_attr = attr
                break

        log(f"Block container attr : {chosen_attr}")
        if blocks is None:
            log("⚠️ No valid block container found.")
            return
        if not hasattr(blocks, "__len__"):
            log("⚠️ Blocks exist but not iterable.")
            return

        total = len(blocks)
        log(f"Total Blocks : {total}")
        log("-" * 80)

        # Inspect first N blocks
        N = min(20, total)
        for i in range(N):
            block = blocks[i]
            log(f"\n🧩 Block [{i}/{total-1}]")
            log(f"Class: {block.__class__.__name__}")

            # Print submodules
            for n, m in block.named_children():
                log(f"  ├─ {n}: {m.__class__.__name__}")

            # Print attention-related info
            if hasattr(block, "attn"):
                attn = block.attn
                log(f"  ├─ Attention: {attn.__class__.__name__}")
                log(f"  │    Heads   : {getattr(attn, 'num_heads', 'unknown')}")
                log(f"  │    Dim     : {getattr(attn, 'hidden_size', 'unknown')}")
                log(f"  │    Backend : {getattr(attn, 'attention_backend', 'unknown')}")

            # Device + dtype info
            try:
                dev = next(block.parameters()).device
                log(f"  ├─ Device : {dev}")
            except StopIteration:
                pass
            try:
                dt = next(block.parameters()).dtype
                log(f"  ├─ DType  : {dt}")
            except StopIteration:
                pass

        log("\n🔚 END TRANSFORMER DEBUG DUMP")
        log("=" * 80)
    except Exception as e:
        log(f"❌ ERROR IN INSPECTOR: {e}")
# ---------- UTILITY ----------
def pretty_header(title):
    log("\n\n" + "=" * 80)
    log(f"🎛️ {title}")
    log("=" * 80 + "\n")

# ---------- MEMORY ----------
def get_vram(prefix=""):
    try:
        allocated = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
        log(f"{prefix}Allocated VRAM : {allocated:.2f} MB")
        log(f"{prefix}Reserved VRAM  : {reserved:.2f} MB")
    except Exception:
        log(f"{prefix}VRAM: CUDA not available")
# ---------- MODULE INSPECT ----------
def inspect_module(name, module):
    pretty_header(f"🔬 Inspecting {name}")
    try:
        log(f"📦 Class  : {module.__class__.__name__}")
        log(f"🔢 DType  : {getattr(module, 'dtype', 'unknown')}")
        log(f"💻 Device : {next(module.parameters()).device}")
        log(f"🧮 Params : {sum(p.numel() for p in module.parameters()):,}")

        # Quantization state
        if hasattr(module, "is_loaded_in_4bit"):
            log(f"⚙️ 4-bit QLoRA : {module.is_loaded_in_4bit}")
        if hasattr(module, "is_loaded_in_8bit"):
            log(f"⚙️ 8-bit load  : {module.is_loaded_in_8bit}")

        # Attention backend (DiT)
        if hasattr(module, "set_attention_backend"):
            try:
                attn = getattr(module, "attention_backend", None)
                log(f"🚀 Attention Backend: {attn}")
            except Exception:
                pass

        # Search for blocks
        candidates = ["transformer_blocks", "blocks", "layers", "encoder", "model"]
        blocks = None
        chosen_attr = None
        for attr in candidates:
            if hasattr(module, attr):
                blocks = getattr(module, attr)
                chosen_attr = attr
                break

        log(f"\n📚 Block Container : {chosen_attr}")
        if blocks is None:
            log("⚠️ No block structure found")
            return
        if not hasattr(blocks, "__len__"):
            log("⚠️ Blocks exist but are not iterable")
            return

        total = len(blocks)
        log(f"🔢 Total Blocks : {total}\n")

        # Inspect first 15 blocks
        N = min(15, total)
        for i in range(N):
            blk = blocks[i]
            log(f"\n🧩 Block [{i}/{total-1}] — {blk.__class__.__name__}")
            for n, m in blk.named_children():
                log(f"  ├─ {n:<15} {m.__class__.__name__}")

            # Attention details
            if hasattr(blk, "attn"):
                a = blk.attn
                log("  ├─ Attention")
                log(f"  │    Heads   : {getattr(a, 'num_heads', 'unknown')}")
                log(f"  │    Dim     : {getattr(a, 'hidden_size', 'unknown')}")
                log(f"  │    Backend : {getattr(a, 'attention_backend', 'unknown')}")

            # Device / dtype
            try:
                log(f"  ├─ Device : {next(blk.parameters()).device}")
                log(f"  ├─ DType  : {next(blk.parameters()).dtype}")
            except StopIteration:
                pass

        get_vram("  ▶ ")
    except Exception as e:
        log(f"❌ Module inspect error: {e}")
# ---------- LORA INSPECTION ----------
def inspect_loras(pipe):
    pretty_header("🧩 LoRA ADAPTERS")
    try:
        if not hasattr(pipe, "lora_state_dict") and not hasattr(pipe, "adapter_names"):
            log("⚠️ No LoRA system detected.")
            return

        if hasattr(pipe, "adapter_names"):
            names = pipe.adapter_names
            log(f"Available Adapters: {names}")
        if hasattr(pipe, "active_adapters"):
            log(f"Active Adapters   : {pipe.active_adapters}")
        if hasattr(pipe, "lora_scale"):
            log(f"LoRA Scale        : {pipe.lora_scale}")

        # LoRA modules
        if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "modules"):
            for name, module in pipe.transformer.named_modules():
                if "lora" in name.lower():
                    log(f"  🔧 LoRA Module: {name} ({module.__class__.__name__})")
    except Exception as e:
        log(f"❌ LoRA inspect error: {e}")
# ---------- PIPELINE INSPECTOR ----------
def debug_pipeline(pipe):
    pretty_header("🚀 FULL PIPELINE DEBUGGING")
    try:
        log(f"Pipeline Class : {pipe.__class__.__name__}")
        log(f"Attention Impl : {getattr(pipe, 'attn_implementation', 'unknown')}")
        log(f"Device         : {pipe.device}")
    except Exception as e:
        log(f"⚠️ Pipeline header inspect failed: {e}")

    get_vram("▶ ")

    # Inspect TRANSFORMER
    if hasattr(pipe, "transformer"):
        inspect_module("Transformer", pipe.transformer)

    # Inspect TEXT ENCODER
    if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None:
        inspect_module("Text Encoder", pipe.text_encoder)

    # Inspect UNET (if the ZImage pipeline has one)
    if hasattr(pipe, "unet"):
        inspect_module("UNet", pipe.unet)

    # LoRA adapters
    inspect_loras(pipe)

    pretty_header("🎉 END DEBUG REPORT")
# ============================================================
# LOAD TRANSFORMER — WITH LIVE STATS
# ============================================================
log("\n===================================================")
log("🔧 LOADING TRANSFORMER BLOCK")
log("===================================================")
log("📌 Logging memory before load:")
log_system_stats("START TRANSFORMER LOAD")

try:
    quant_cfg = DiffusersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_use_double_quant=True,
    )
    transformer = AutoModel.from_pretrained(
        model_id,
        cache_dir=model_cache,
        subfolder="transformer",
        quantization_config=quant_cfg,
        torch_dtype=torch_dtype,
        device_map=device,
    )
    log("✅ Transformer loaded successfully.")
except Exception as e:
    log(f"❌ Transformer load failed: {e}")
    transformer = None

log_system_stats("AFTER TRANSFORMER LOAD")

if transformer:
    inspect_transformer(transformer, "Transformer")
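# Hedged alternative (unused sketch): if NF4 4-bit quantization causes quality
# issues and VRAM allows, a plain 8-bit BitsAndBytes config could be swapped into
# the from_pretrained() call above. load_in_8bit is a standard BitsAndBytesConfig
# option; whether the Z-Image transformer tolerates it is an assumption.
quant_cfg_8bit = DiffusersBitsAndBytesConfig(load_in_8bit=True)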
# ============================================================
# LOAD TEXT ENCODER
# ============================================================
log("\n===================================================")
log("🔧 LOADING TEXT ENCODER")
log("===================================================")
log_system_stats("START TEXT ENCODER LOAD")

try:
    quant_cfg2 = TransformersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_use_double_quant=True,
    )
    text_encoder = AutoModel.from_pretrained(
        model_id,
        cache_dir=model_cache,
        subfolder="text_encoder",
        quantization_config=quant_cfg2,
        torch_dtype=torch_dtype,
        device_map=device,
    )
    log("✅ Text encoder loaded successfully.")
except Exception as e:
    log(f"❌ Text encoder load failed: {e}")
    text_encoder = None

log_system_stats("AFTER TEXT ENCODER LOAD")

if text_encoder:
    inspect_transformer(text_encoder, "Text Encoder")
# ============================================================
# BUILD PIPELINE
# ============================================================
log("\n===================================================")
log("🔧 BUILDING PIPELINE")
log("===================================================")
log_system_stats("START PIPELINE BUILD")

try:
    pipe = ZImagePipeline.from_pretrained(
        model_id,
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch_dtype,
    )

    # If the transformer supports setting a backend, prefer FlashAttention-3
    try:
        if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "set_attention_backend"):
            pipe.transformer.set_attention_backend("_flash_3")
            log("✅ transformer.set_attention_backend('_flash_3') called")
    except Exception as _e:
        log(f"⚠️ set_attention_backend failed: {_e}")

    # Default LoRA load (keeps the existing behaviour)
    try:
        pipe.load_lora_weights(
            "rahul7star/ZImageLora",
            weight_name="NSFW/doggystyle_pov.safetensors",
            adapter_name="lora",
        )
        pipe.set_adapters(["lora"], adapter_weights=[1.0])
        pipe.fuse_lora(adapter_names=["lora"], lora_scale=0.75)
    except Exception as _e:
        log(f"⚠️ Default LoRA load failed: {_e}")

    debug_pipeline(pipe)
    # pipe.unload_lora_weights()
    pipe.to("cuda")
    log("✅ Pipeline built successfully.")
    log("Pipeline build completed.")  # log() already appends to LOGS; no need to append twice
except Exception as e:
    log(f"❌ Pipeline build failed: {e}")
    log(traceback.format_exc())
    pipe = None

log_system_stats("AFTER PIPELINE BUILD")
# -----------------------------
# Monkey-patch prepare_latents (safe)
# -----------------------------
if pipe is not None and hasattr(pipe, "prepare_latents"):
    original_prepare_latents = pipe.prepare_latents

    def logged_prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        try:
            result_latents = original_prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator, latents)
            log_msg = f"🔹 prepare_latents called | shape={result_latents.shape}, dtype={result_latents.dtype}, device={result_latents.device}"
            if hasattr(self, "_latents_log"):
                self._latents_log.append(log_msg)
            else:
                self._latents_log = [log_msg]
            return result_latents
        except Exception as e:
            log(f"⚠️ prepare_latents wrapper failed: {e}")
            raise

    # apply the patch safely
    try:
        pipe.prepare_latents = logged_prepare_latents.__get__(pipe)
        log("✅ prepare_latents monkey-patched")
    except Exception as e:
        log(f"⚠️ Failed to attach prepare_latents patch: {e}")
else:
    log("❌ WARNING: Pipe not initialized or prepare_latents missing; skipping prepare_latents patch")
from PIL import Image
import torch

# --------------------------
# Helper: Safe latent extractor
# --------------------------
def safe_get_latents(pipe, height, width, generator, device, LOGS):
    """
    Safely prepare latents for any ZImagePipeline variant.
    Returns a latents tensor and logs issues instead of failing.
    """
    try:
        # Determine number of latent channels
        num_channels = 4  # default fallback
        if hasattr(pipe, "unet") and hasattr(pipe.unet, "in_channels"):
            num_channels = pipe.unet.in_channels
        elif hasattr(pipe, "vae") and hasattr(pipe.vae, "latent_channels"):
            num_channels = pipe.vae.latent_channels  # some pipelines define this

        LOGS.append(f"🔹 Using num_channels={num_channels} for latents")

        latents = pipe.prepare_latents(
            batch_size=1,
            num_channels_latents=num_channels,
            height=height,
            width=width,
            dtype=torch.float32,
            device=device,
            generator=generator,
        )
        LOGS.append(f"🔹 Latents shape: {latents.shape}, dtype: {latents.dtype}, device: {latents.device}")
        return latents
    except Exception as e:
        LOGS.append(f"⚠️ Latent extraction failed: {e}")
        # fallback: guess a safe shape
        fallback_channels = 16  # common default for Z-Image-style pipelines
        latents = torch.randn(
            (1, fallback_channels, height // 8, width // 8),
            generator=generator,
            device=device,
        )
        LOGS.append(f"🔹 Using fallback random latents shape: {latents.shape}")
        return latents
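# Example usage (sketch, intentionally not executed at import time): prepare
# 1024x1024 latents with a fixed seed and a throwaway log list.
#   gen = torch.Generator("cuda").manual_seed(0)
#   lat = safe_get_latents(pipe, 1024, 1024, gen, "cuda", [])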
# --------------------------
# Main generation function (kept exactly as the original logic)
# --------------------------
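# This Space runs on ZeroGPU ("Running on Zero") and `spaces` is imported above
# but never used. The standard ZeroGPU pattern is to decorate the GPU-bound entry
# point with @spaces.GPU (it also supports generator functions); added here on
# that assumption.
@spaces.GPU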
def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
    LOGS = []
    device = "cuda"
    # Gradio sliders/number boxes may deliver floats
    height, width, steps = int(height), int(width), int(steps)
    generator = torch.Generator(device).manual_seed(int(seed))

    # placeholders
    placeholder = Image.new("RGB", (width, height), color=(255, 255, 255))
    latent_gallery = []
    final_gallery = []

    try:
        # --- Try advanced latent mode ---
        try:
            latents = safe_get_latents(pipe, height, width, generator, device, LOGS)

            # ensure the scheduler actually has timesteps for the requested step count
            pipe.scheduler.set_timesteps(steps, device=device)

            for i, t in enumerate(pipe.scheduler.timesteps):
                # Step-wise denoising
                with torch.no_grad():
                    noise_pred = pipe.unet(latents, t, encoder_hidden_states=pipe.get_text_embeddings(prompt))["sample"]
                    latents = pipe.scheduler.step(noise_pred, t, latents)["prev_sample"]

                # Convert latent to preview image
                try:
                    latent_img = latent_to_image(latents)  # returns a single PIL image
                except Exception:
                    latent_img = placeholder
                latent_gallery.append(latent_img)

                # Yield intermediate update: three values to match the UI outputs
                # [final_image, latent_gallery, logs_box]
                yield None, latent_gallery, "\n".join(LOGS)

            # decode the final image after all timesteps
            final_img = pipe.decode_latents(latents)[0]
            final_gallery.append(final_img)
            LOGS.append("✅ Advanced latent pipeline succeeded.")
            yield final_img, latent_gallery, "\n".join(LOGS)

        except Exception as e:
            LOGS.append(f"⚠️ Advanced latent mode failed: {e}")
            LOGS.append("🔁 Switching to standard pipeline...")

            # Standard pipeline fallback
            try:
                output = pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                )
                final_img = output.images[0]
                final_gallery.append(final_img)
                latent_gallery.append(final_img)  # optionally show in the latent gallery as the last step
                LOGS.append("✅ Standard pipeline succeeded.")
                yield final_img, latent_gallery, "\n".join(LOGS)
            except Exception as e2:
                LOGS.append(f"❌ Standard pipeline failed: {e2}")
                final_gallery.append(placeholder)
                latent_gallery.append(placeholder)
                yield placeholder, latent_gallery, "\n".join(LOGS)

    except Exception as e:
        LOGS.append(f"❌ Total failure: {e}")
        final_gallery.append(placeholder)
        latent_gallery.append(placeholder)
        yield placeholder, latent_gallery, "\n".join(LOGS)
def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
    """
    Robust dual pipeline:
    - Advanced latent generation first
    - Fallback to the standard pipeline if latent preparation fails
    - Always returns a final image
    - Returns gallery (latents or final image) and logs
    """
    LOGS = []
    image = None
    latents = None
    gallery = []

    # Keep a placeholder original image (white) in case everything fails
    original_image = Image.new("RGB", (width, height), color=(255, 255, 255))

    try:
        generator = torch.Generator(device).manual_seed(int(seed))

        # -------------------------------
        # Try advanced latent generation
        # -------------------------------
        try:
            batch_size = 1
            num_channels_latents = getattr(pipe.unet, "in_channels", None)
            if num_channels_latents is None:
                raise AttributeError("pipe.unet.in_channels not found, fallback to standard pipeline")

            latents = pipe.prepare_latents(
                batch_size=batch_size,
                num_channels_latents=num_channels_latents,  # keyword aligned with prepare_latents' signature
                height=height,
                width=width,
                dtype=torch.float32,
                device=device,
                generator=generator,
            )
            LOGS.append(f"✅ Latents prepared: {latents.shape}")

            output = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                generator=generator,
                latents=latents,
            )
            image = output.images[0]
            gallery = [image] if image else []
            LOGS.append("✅ Advanced latent generation succeeded.")

        # -------------------------------
        # Fallback to standard pipeline
        # -------------------------------
        except Exception as e_latent:
            LOGS.append(f"⚠️ Advanced latent generation failed: {e_latent}")
            LOGS.append("🔁 Falling back to standard pipeline...")
            try:
                output = pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                )
                image = output.images[0]
                gallery = [image] if image else []
                LOGS.append("✅ Standard pipeline generation succeeded.")
            except Exception as e_standard:
                LOGS.append(f"❌ Standard pipeline generation failed: {e_standard}")
                image = original_image  # Always return some image
                gallery = [image]

        # -------------------------------
        # Return all 3 outputs
        # -------------------------------
        return image, gallery, LOGS

    except Exception as e:
        LOGS.append(f"❌ Inference failed entirely: {e}")
        return original_image, [original_image], LOGS
# ============================================================
# UI
# ============================================================

# Utility: scan the local HF cache for safetensors under a repo folder name
def list_loras_from_repo(repo_id):
    """
    Attempts to find safetensors inside the HF cache directory for repo_id.
    This only scans the local cache; it does NOT download anything.

    Returns:
        A list of strings suitable for showing in the dropdown. Prefers paths
        relative to the repo root (e.g. "NSFW/doggystyle_pov.safetensors") so that
        pipe.load_lora_weights(repo_id, weight_name=that_path) works for nested files.
        If a relative path can't be determined, returns absolute cached file paths.
    """
    if not repo_id:
        return []

    safe_list = []

    # Candidate cache roots
    hf_cache = os.path.expanduser("~/.cache/huggingface/hub")
    alt_cache = "/home/user/.cache/huggingface/hub"
    candidates = [hf_cache, alt_cache]

    # Normalize repo variants to search for in the path
    owner_repo = repo_id.replace("/", "_")
    owner_repo_dash = repo_id.replace("/", "-")
    owner_repo_double = repo_id.replace("/", "--")

    # Walk the caches and collect safetensors
    for root_cache in candidates:
        if not os.path.exists(root_cache):
            continue
        for dirpath, dirnames, filenames in os.walk(root_cache):
            for f in filenames:
                if not f.endswith(".safetensors"):
                    continue
                full_path = os.path.join(dirpath, f)

                # try to find a repo-root-like substring in dirpath
                chosen_base = None
                for pattern in (owner_repo_double, owner_repo_dash, owner_repo):
                    idx = dirpath.find(pattern)
                    if idx != -1:
                        chosen_base = dirpath[: idx + len(pattern)]
                        break

                # fallback: look for the repo folder name (last component), e.g. "ZImageLora"
                if chosen_base is None:
                    repo_tail = repo_id.split("/")[-1]
                    idx2 = dirpath.find(repo_tail)
                    if idx2 != -1:
                        chosen_base = dirpath[: idx2 + len(repo_tail)]

                # If we found a base that looks like the cached repo root, compute the relative path
                if chosen_base:
                    try:
                        rel = os.path.relpath(full_path, chosen_base)
                        # If relpath goes up (starts with ..) then prefer full_path
                        if rel and not rel.startswith(".."):
                            # Normalize to forward slashes for HF repo weight_name usage
                            rel_normalized = rel.replace(os.sep, "/")
                            safe_list.append(rel_normalized)
                            continue
                    except Exception:
                        pass

                # Otherwise append the absolute path (last resort)
                safe_list.append(full_path)

    # remove duplicates and sort
    safe_list = sorted(dict.fromkeys(safe_list))
    return safe_list
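# Hedged alternative (not wired into the UI): query the Hub directly via
# huggingface_hub.list_repo_files instead of relying on the local cache layout.
# list_repo_files is a standard huggingface_hub API; having the package available
# in this Space is assumed (diffusers already depends on it).
def list_loras_from_hub(repo_id):
    try:
        from huggingface_hub import list_repo_files
        return sorted(f for f in list_repo_files(repo_id) if f.endswith(".safetensors"))
    except Exception as e:
        log(f"⚠️ Hub listing failed for {repo_id}: {e}")
        return []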
with gr.Blocks(title="Z-Image-Turbo") as demo:
    with gr.Tabs():
        with gr.TabItem("Image & Latents"):
            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Textbox(label="Prompt", value="boat in Ocean")
                    height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
                    width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
                    steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
                    seed = gr.Number(value=42, label="Seed")
                    run_btn = gr.Button("Generate Image")
                with gr.Column(scale=1):
                    final_image = gr.Image(label="Final Image")
                    latent_gallery = gr.Gallery(
                        label="Latent Steps", columns=4, height=256, preview=True
                    )
        with gr.TabItem("Logs"):
            logs_box = gr.Textbox(label="All Logs", lines=25)

    # New UI: LoRA repo textbox, dropdown, refresh & rebuild
    with gr.Row():
        lora_repo = gr.Textbox(label="LoRA Repo (HF id)", value="rahul7star/ZImageLora", placeholder="e.g. rahul7star/ZImageLora")
        lora_dropdown = gr.Dropdown(choices=[], label="LoRA files (from local cache)")
        refresh_lora_btn = gr.Button("Refresh LoRA List")
        rebuild_pipe_btn = gr.Button("Rebuild pipeline (use selected LoRA)")

    # Refresh callback: repopulate the dropdown from the repo text
    def refresh_lora_list(repo_name):
        try:
            files = list_loras_from_repo(repo_name)
            if not files:
                return gr.update(choices=[], value=None)
            return gr.update(choices=files, value=files[0])
        except Exception as e:
            log(f"⚠️ refresh_lora_list failed: {e}")
            return gr.update(choices=[], value=None)

    refresh_lora_btn.click(refresh_lora_list, inputs=[lora_repo], outputs=[lora_dropdown])

    # Rebuild callback: build the pipeline with the selected LoRA file path (if any)
    def rebuild_pipeline_with_lora(lora_path, repo_name):
        global pipe, LOGS
        try:
            log(f"🔄 Rebuilding pipeline using LoRA repo={repo_name} file={lora_path}")
            # reuse the existing logic: create a new pipeline, then load the LoRA file
            pipe = ZImagePipeline.from_pretrained(
                model_id,
                transformer=transformer,
                text_encoder=text_encoder,
                torch_dtype=torch_dtype,
            )
            # try to set the attention backend
            try:
                if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "set_attention_backend"):
                    pipe.transformer.set_attention_backend("_flash_3")
            except Exception as _e:
                log(f"⚠️ set_attention_backend failed during rebuild: {_e}")

            # load the selected LoRA if provided
            if lora_path:
                weight_name_to_use = None
                # If the dropdown provided a relative-style path (contains a slash and is not absolute),
                # use it directly as weight_name (HF expects "path/inside/repo.safetensors")
                if ("/" in lora_path) and not os.path.isabs(lora_path):
                    weight_name_to_use = lora_path
                else:
                    # It might be an absolute path in the cache; try to compute a path relative to the repo cache root
                    abs_path = lora_path if os.path.isabs(lora_path) else None
                    if abs_path and os.path.exists(abs_path):
                        # attempt to find a repo-root-ish substring in abs_path
                        repo_variants = [
                            repo_name.replace("/", "--"),
                            repo_name.replace("/", "-"),
                            repo_name.replace("/", "_"),
                            repo_name.split("/")[-1],
                        ]
                        chosen_base = None
                        for v in repo_variants:
                            idx = abs_path.find(v)
                            if idx != -1:
                                chosen_base = abs_path[: idx + len(v)]
                                break
                        if chosen_base:
                            try:
                                rel = os.path.relpath(abs_path, chosen_base)
                                if rel and not rel.startswith(".."):
                                    weight_name_to_use = rel.replace(os.sep, "/")
                            except Exception:
                                weight_name_to_use = None
                    # fallback to the basename
                    if weight_name_to_use is None:
                        weight_name_to_use = os.path.basename(lora_path)

                # Now attempt to load
                try:
                    pipe.load_lora_weights(
                        repo_name or "rahul7star/ZImageLora",
                        weight_name=weight_name_to_use,
                        adapter_name="lora",
                    )
                    pipe.set_adapters(["lora"], adapter_weights=[1.0])
                    pipe.fuse_lora(adapter_names=["lora"], lora_scale=0.75)
                    log(f"✅ Loaded LoRA weight: {weight_name_to_use} from repo {repo_name}")
                except Exception as _e:
                    log(f"⚠️ Failed to load selected LoRA during rebuild using weight_name='{weight_name_to_use}': {_e}")
                    # as a last resort, try loading by basename
                    try:
                        fallback_name = os.path.basename(lora_path)
                        pipe.load_lora_weights(
                            repo_name or "rahul7star/ZImageLora",
                            weight_name=fallback_name,
                            adapter_name="lora",
                        )
                        pipe.set_adapters(["lora"], adapter_weights=[1.0])
                        pipe.fuse_lora(adapter_names=["lora"], lora_scale=0.75)
                        log(f"✅ Fallback loaded LoRA weight basename: {fallback_name}")
                    except Exception as _e2:
                        log(f"❌ Fallback LoRA load also failed: {_e2}")

            # finalize
            debug_pipeline(pipe)
            pipe.to("cuda")

            # re-attach the monkey patch safely
            if pipe is not None and hasattr(pipe, "prepare_latents"):
                try:
                    original_prepare = pipe.prepare_latents

                    def logged_prepare(self, *args, **kwargs):
                        lat = original_prepare(*args, **kwargs)
                        msg = f"🔹 prepare_latents called | shape={lat.shape}, dtype={lat.dtype}"
                        if hasattr(self, "_latents_log"):
                            self._latents_log.append(msg)
                        else:
                            self._latents_log = [msg]
                        return lat

                    pipe.prepare_latents = logged_prepare.__get__(pipe)
                    log("✅ Re-applied prepare_latents monkey patch after rebuild")
                except Exception as _e:
                    log(f"⚠️ Could not re-apply prepare_latents patch: {_e}")

            return "\n".join([LOGS, "Rebuild complete."])
        except Exception as e:
            log(f"❌ Rebuild pipeline failed: {e}")
            log(traceback.format_exc())
            return "\n".join([LOGS, f"Rebuild failed: {e}"])

    rebuild_pipe_btn.click(rebuild_pipeline_with_lora, inputs=[lora_dropdown, lora_repo], outputs=[logs_box])

    # Wire the generate button AFTER all components exist
    run_btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[final_image, latent_gallery, logs_box],
    )

demo.launch()