"""Z-Image-i2L: image-to-LoRA Gradio demo (app.py)."""
import spaces
import gradio as gr
import torch
from PIL import Image
import os
import sys
import subprocess
import tempfile
from pathlib import Path
import glob
# Default negative prompts
NEGATIVE_PROMPT_CN = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
NEGATIVE_PROMPT_EN = "Yellowed, green-tinted, blurry, low-resolution, low-quality image, distorted limbs, eerie appearance, ugly, AI-looking, noise, grid-like artifacts, JPEG compression artifacts, abnormal limbs, watermark, garbled text, meaningless characters"
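# NEGATIVE_PROMPT_EN is an English translation of the Chinese default; Z-Image accepts prompts in either language.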
# Model paths - can be overridden via environment variables
MODELS_DIR = Path(os.environ.get("ZIMAGE_MODELS_DIR", "./models"))
# =============================================================================
# Model Download Functions
# =============================================================================
def download_hf_models(output_dir: Path) -> dict:
"""
Download required models from Hugging Face using huggingface_hub.
Downloads:
- DiffSynth-Studio/Z-Image-i2L
- Tongyi-MAI/Z-Image
- DiffSynth-Studio/General-Image-Encoders
- Tongyi-MAI/Z-Image-Turbo
Returns dict with paths to downloaded models.
"""
from huggingface_hub import snapshot_download
output_dir.mkdir(parents=True, exist_ok=True)
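    # allow_patterns restricts each snapshot to the files the pipeline actually
    # loads, which keeps the downloads small; None fetches the whole repo.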
models = [
{
"repo_id": "DiffSynth-Studio/General-Image-Encoders",
"description": "General Image Encoders (SigLIP2-G384, DINOv3-7B)",
"allow_patterns": None,
},
{
"repo_id": "Tongyi-MAI/Z-Image-Turbo",
"description": "Z-Image Turbo (text encoder, VAE, tokenizer)",
"allow_patterns": [
"text_encoder/*.safetensors",
"vae/*.safetensors",
"tokenizer/*",
],
},
{
"repo_id": "Tongyi-MAI/Z-Image",
"description": "Z-Image base model (transformer)",
"allow_patterns": ["transformer/*.safetensors"],
},
{
"repo_id": "DiffSynth-Studio/Z-Image-i2L",
"description": "Z-Image-i2L (Image to LoRA model)",
"allow_patterns": ["*.safetensors"],
},
]
downloaded_paths = {}
for model in models:
repo_id = model["repo_id"]
local_dir = output_dir / repo_id
# Check if already downloaded
if local_dir.exists() and any(local_dir.rglob("*.safetensors")):
print(f" ✓ {repo_id} (already downloaded)")
downloaded_paths[repo_id] = local_dir
continue
print(f" 📥 Downloading {repo_id}...")
print(f" {model['description']}")
try:
result_path = snapshot_download(
repo_id=repo_id,
local_dir=str(local_dir),
allow_patterns=model["allow_patterns"],
)
downloaded_paths[repo_id] = Path(result_path)
print(f" ✓ {repo_id}")
except Exception as e:
print(f" ❌ Error downloading {repo_id}: {e}")
raise
return downloaded_paths
def get_model_files(base_path: Path, pattern: str) -> list:
"""Get list of files matching a glob pattern."""
full_pattern = str(base_path / pattern)
files = sorted(glob.glob(full_pattern))
return files
def install_diffsynth_studio():
"""Clone and install DiffSynth-Studio if not already installed."""
try:
from diffsynth.pipelines.z_image import ZImagePipeline
return True, "✅ DiffSynth-Studio is already installed."
except ImportError:
pass
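        # Not importable yet: fall through to the clone + editable install below.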
repo_dir = Path(__file__).parent / "DiffSynth-Studio"
try:
if not repo_dir.exists():
print("📥 Cloning DiffSynth-Studio repository...")
subprocess.run(
["git", "clone", "https://github.com/modelscope/DiffSynth-Studio.git", str(repo_dir)],
capture_output=True,
text=True,
check=True
)
print("✅ Repository cloned successfully.")
else:
print("📁 DiffSynth-Studio directory already exists, pulling latest...")
subprocess.run(
["git", "-C", str(repo_dir), "pull"],
capture_output=True,
text=True
)
print("📦 Installing DiffSynth-Studio...")
subprocess.run(
[sys.executable, "-m", "pip", "install", "-e", str(repo_dir)],
capture_output=True,
text=True,
check=True
)
print("✅ DiffSynth-Studio installed successfully.")
sys.path.insert(0, str(repo_dir))
from diffsynth.pipelines.z_image import ZImagePipeline
return True, "✅ DiffSynth-Studio installed successfully!"
except subprocess.CalledProcessError as e:
error_msg = f"❌ Installation failed: {e.stderr}"
print(error_msg)
return False, error_msg
except Exception as e:
error_msg = f"❌ Error during installation: {str(e)}"
print(error_msg)
return False, error_msg
# =============================================================================
# Pipeline Initialization
# =============================================================================
print("=" * 60)
print(" Z-Image-i2L Gradio Demo - Initializing")
print("=" * 60)
print()
# Step 1: Install DiffSynth-Studio
print("🔍 Step 1: Checking DiffSynth-Studio installation...")
success, message = install_diffsynth_studio()
print(message)
if not success:
raise RuntimeError("Failed to install DiffSynth-Studio. Cannot continue.")
# Step 2: Download HuggingFace models
print()
print("🔍 Step 2: Downloading models from HuggingFace...")
print(f" Models directory: {MODELS_DIR.absolute()}")
downloaded_paths = download_hf_models(MODELS_DIR)
# Import DiffSynth modules (importable only now that Step 1 has installed the package)
from diffsynth.pipelines.z_image import (
ZImagePipeline, ModelConfig,
ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
)
from safetensors.torch import save_file, load_file
# Step 3: Configure VRAM settings
print()
print("⚙️ Step 3: Configuring VRAM settings...")
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cuda",
"onload_dtype": torch.bfloat16,
"onload_device": "cuda",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
# Step 4: Resolve local model paths
print()
print("📂 Step 4: Resolving model paths...")
# Z-Image transformer
zimage_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image"
zimage_transformer_files = get_model_files(zimage_path, "transformer/*.safetensors")
# Z-Image-Turbo
zimage_turbo_path = MODELS_DIR / "Tongyi-MAI" / "Z-Image-Turbo"
text_encoder_files = get_model_files(zimage_turbo_path, "text_encoder/*.safetensors")
vae_file = get_model_files(zimage_turbo_path, "vae/diffusion_pytorch_model.safetensors")
tokenizer_path = zimage_turbo_path / "tokenizer"
# General Image Encoders
encoders_path = MODELS_DIR / "DiffSynth-Studio" / "General-Image-Encoders"
siglip_file = get_model_files(encoders_path, "SigLIP2-G384/model.safetensors")
dino_file = get_model_files(encoders_path, "DINOv3-7B/model.safetensors")
# Z-Image-i2L from HuggingFace
zimage_i2l_path = MODELS_DIR / "DiffSynth-Studio" / "Z-Image-i2L"
zimage_i2l_file = get_model_files(zimage_i2l_path, "model.safetensors")
print(f" Z-Image transformer: {len(zimage_transformer_files)} file(s)")
print(f" Text encoder: {len(text_encoder_files)} file(s)")
print(f" VAE: {len(vae_file)} file(s)")
print(f" Tokenizer: {tokenizer_path}")
print(f" SigLIP2: {len(siglip_file)} file(s)")
print(f" DINOv3: {len(dino_file)} file(s)")
print(f" Z-Image-i2L: {len(zimage_i2l_file)} file(s)")
# Validate files
missing = []
if not zimage_transformer_files: missing.append("Z-Image transformer")
if not text_encoder_files: missing.append("Text encoder")
if not vae_file: missing.append("VAE")
if not tokenizer_path.exists(): missing.append("Tokenizer")
if not siglip_file: missing.append("SigLIP2")
if not dino_file: missing.append("DINOv3")
if not zimage_i2l_file: missing.append("Z-Image-i2L")
if missing:
raise FileNotFoundError(f"Missing model files: {', '.join(missing)}")
# Step 5: Load pipeline
print()
print("🚀 Step 5: Loading Z-Image pipeline...")
print(" All models loaded from HuggingFace local paths")
model_configs = [
# All models from HuggingFace - use path= for local files
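    # Only the large transformer gets the explicit placement config; the remaining
    # components fall back to the pipeline defaults.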
ModelConfig(path=zimage_transformer_files, **vram_config),
ModelConfig(path=text_encoder_files),
ModelConfig(path=vae_file),
ModelConfig(path=siglip_file),
ModelConfig(path=dino_file),
ModelConfig(path=zimage_i2l_file),
]
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=model_configs,
tokenizer_config=ModelConfig(path=str(tokenizer_path)),
)
print()
print("✅ Pipeline loaded successfully!")
print("=" * 60)
print()
# =============================================================================
# Gradio Functions
# =============================================================================
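# @spaces.GPU reserves a ZeroGPU slot for up to `duration` seconds per call.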
@spaces.GPU(duration=120)
def image_to_lora(images, progress=gr.Progress()):
"""Convert input images to a LoRA model."""
if images is None or len(images) == 0:
return None, "❌ Please upload at least one image!"
try:
progress(0.1, desc="Processing images...")
pil_images = []
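        # gr.Gallery may hand items back as file paths, (path, caption) tuples, or
        # raw arrays depending on the Gradio version; normalize everything to RGB PIL images.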
for img in images:
if isinstance(img, str):
pil_images.append(Image.open(img).convert("RGB"))
elif isinstance(img, tuple):
pil_images.append(Image.open(img[0]).convert("RGB"))
else:
pil_images.append(Image.fromarray(img).convert("RGB"))
progress(0.3, desc="Encoding images to LoRA...")
with torch.no_grad():
embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=pil_images)
progress(0.7, desc="Decoding LoRA weights...")
lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
progress(0.9, desc="Saving LoRA file...")
temp_dir = tempfile.mkdtemp()
lora_path = os.path.join(temp_dir, "generated_lora.safetensors")
save_file(lora, lora_path)
progress(1.0, desc="Done!")
return lora_path, f"✅ LoRA generated successfully from {len(pil_images)} image(s)!"
except Exception as e:
return None, f"❌ Error generating LoRA: {str(e)}"
@spaces.GPU(duration=60)
def generate_image(
lora_file,
prompt,
negative_prompt,
seed,
cfg_scale,
sigma_shift,
num_steps,
progress=gr.Progress()
):
"""Generate an image using the created LoRA."""
if lora_file is None:
return None, "❌ Please generate or upload a LoRA file first!"
try:
progress(0.1, desc="Loading LoRA...")
lora = load_file(lora_file)
# Move LoRA tensors to CUDA with correct dtype
lora = {k: v.to(device="cuda", dtype=torch.bfloat16) for k, v in lora.items()}
progress(0.3, desc="Generating image...")
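        # positive_only_lora applies the LoRA on the positive (conditional) branch only,
        # as the parameter name suggests, so CFG contrasts styled vs. unstyled predictions.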
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
seed=int(seed),
cfg_scale=cfg_scale,
num_inference_steps=int(num_steps),
positive_only_lora=lora,
sigma_shift=sigma_shift
)
progress(1.0, desc="Done!")
return image, "✅ Image generated successfully!"
except Exception as e:
return None, f"❌ Error generating image: {str(e)}"
def create_demo():
"""Create the Gradio interface."""
with gr.Blocks(
title="Z-Image-i2L Demo",
theme=gr.themes.Soft(),
css=".gradio-container { max-width: 1200px !important; margin: 0 auto}"
) as demo:
gr.Markdown("""
# 🎨 Z-Image-i2L: Image to LoRA Demo
> 💡 **Tip**: For best results, use 4-6 images with a consistent artistic style.
""")
with gr.Tabs():
with gr.TabItem("📸 Step 1: Image to LoRA"):
with gr.Row():
with gr.Column(scale=1):
input_gallery = gr.Gallery(
label="Upload Style Images (1-6 images)",
file_types=["image"],
columns=3,
height=300,
interactive=True
)
gr.Markdown("""
**Guidelines:**
- Upload 1-6 images with a consistent style
- Higher quality images produce better results
- Mix of subjects helps generalization
""")
generate_lora_btn = gr.Button("🎯 Generate LoRA", variant="primary")
with gr.Column(scale=1):
lora_output = gr.File(
label="Generated LoRA File",
file_types=[".safetensors"],
interactive=False
)
lora_status = gr.Textbox(
label="Status",
interactive=False,
lines=2
)
with gr.TabItem("🖼️ Step 2: Generate Images"):
with gr.Row():
with gr.Column(scale=1):
lora_input = gr.File(
label="LoRA File (from Step 1 or upload)",
file_types=[".safetensors"]
)
prompt = gr.Textbox(
label="Prompt",
placeholder="Describe what you want to generate...",
value="a cat",
lines=2
)
with gr.Accordion("Negative Prompt", open=False):
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=NEGATIVE_PROMPT_CN,
lines=3
)
with gr.Row():
use_cn_neg = gr.Button("Use Chinese", size="sm")
use_en_neg = gr.Button("Use English", size="sm")
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Number(label="Seed", value=0, precision=0)
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=10, value=4, step=0.5)
sigma_shift = gr.Slider(label="Sigma Shift", minimum=1, maximum=15, value=8, step=1)
num_steps = gr.Slider(label="Steps", minimum=20, maximum=100, value=50, step=5)
generate_btn = gr.Button("✨ Generate Image", variant="primary")
with gr.Column(scale=1):
output_image = gr.Image(label="Generated Image", type="pil", height=512)
gen_status = gr.Textbox(label="Status", interactive=False, lines=2)
gr.Markdown("""
---
**Resources:** [Z-Image-i2L (HuggingFace)](https://huggingface.co/DiffSynth-Studio/Z-Image-i2L) |
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) |
            **Recommended settings:** CFG = 4, Sigma Shift = 8, Steps = 50
""")
# Event handlers
generate_lora_btn.click(
fn=image_to_lora,
inputs=[input_gallery],
outputs=[lora_output, lora_status]
)
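        # Auto-forward the LoRA generated in Step 1 into the Step 2 file input.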
lora_output.change(fn=lambda x: x, inputs=[lora_output], outputs=[lora_input])
generate_btn.click(
fn=generate_image,
inputs=[lora_input, prompt, negative_prompt, seed, cfg_scale, sigma_shift, num_steps],
outputs=[output_image, gen_status]
)
use_cn_neg.click(fn=lambda: NEGATIVE_PROMPT_CN, outputs=[negative_prompt])
use_en_neg.click(fn=lambda: NEGATIVE_PROMPT_EN, outputs=[negative_prompt])
return demo
if __name__ == "__main__":
print("Starting Gradio server...")
demo = create_demo()
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)