import os, json
from typing import List, Dict, Any, Optional
from PIL import Image
import torch
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
# Config (set in Space Secrets if private)
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora").strip()
CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors").strip()
HF_TOKEN = os.getenv("HF_TOKEN", None)
DO_WARMUP = os.getenv("WARMUP", "1") == "1" # set WARMUP=0 to skip the first warmup call
# Optional override: JSON string for LoRA manifest (same shape as loras.json)
LORAS_JSON = os.getenv("LORAS_JSON", "").strip()
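# Example LORAS_JSON value (same shape as loras.json and the built-in fallback
# below; the adapter names, repo ids, and file names here are placeholders only):
#   {"MyStyle": {"repo": "someuser/some-lora-repo", "weight_name": "my_style.safetensors"},
#    "LocalLoRA": {"path": "loras/local_lora.safetensors"}}  # "path" is resolved inside the model repo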
# Where snapshot_download caches the repo in the container
REPO_DIR = "/home/user/model"
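# Scheduler choices exposed in the UI. "default" keeps whatever scheduler the
# checkpoint loader configured; any other entry is swapped in via from_config()
# inside txt2img().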
SCHEDULERS = {
    "default": None,
    "euler_a": EulerAncestralDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "ddim": DDIMScheduler,
    "lms": LMSDiscreteScheduler,
    "pndm": PNDMScheduler,
    "dpmpp_2m": DPMSolverMultistepScheduler,
}
# Globals populated at startup
pipe = None
IS_SDXL = True
LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
INIT_ERROR: Optional[str] = None

def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
    """Manifest load order:
    1) Environment variable LORAS_JSON (if provided)
    2) loras.json inside the downloaded model repo
    3) loras.json at the Space root (next to app.py)
    4) Built-in fallback (the bundled MoriiMee_Gothic entry)
    """
    # 1) From env JSON
    if LORAS_JSON:
        try:
            parsed = json.loads(LORAS_JSON)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse LORAS_JSON: {e}")
    # 2) From repo
    repo_manifest = os.path.join(repo_dir, "loras.json")
    if os.path.exists(repo_manifest):
        try:
            with open(repo_manifest, "r", encoding="utf-8") as f:
                parsed = json.load(f)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse repo loras.json: {e}")
    # 3) From Space root
    local_manifest = os.path.join(os.getcwd(), "loras.json")
    if os.path.exists(local_manifest):
        try:
            with open(local_manifest, "r", encoding="utf-8") as f:
                parsed = json.load(f)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse local loras.json: {e}")
    # 4) Built-in fallback: the bundled MoriiMee Gothic LoRA
    print("[INFO] Using built-in LoRA fallback manifest.")
    return {
        "MoriiMee_Gothic": {
            "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1",
            "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors"
        }
    }

def bootstrap_model():
    """
    Downloads MODEL_REPO_ID into REPO_DIR and loads the single-file checkpoint,
    keeping weights on CPU; ZeroGPU attaches the GPU only inside @spaces.GPU calls.
    """
    global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
    INIT_ERROR = None
    if not MODEL_REPO_ID or not CHECKPOINT_FILENAME:
        INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return
    try:
        local_dir = snapshot_download(
            repo_id=MODEL_REPO_ID,
            token=HF_TOKEN,
            local_dir=REPO_DIR,
            ignore_patterns=["*.md"],
        )
    except Exception as e:
        INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}"
        print(f"[ERROR] {INIT_ERROR}")
        return
    ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME)
    if not os.path.exists(ckpt_path):
        INIT_ERROR = f"Checkpoint not found at {ckpt_path}. Check CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return
    try:
        # Attempt SDXL first (text_encoder_2 present)
        _pipe = StableDiffusionXLPipeline.from_single_file(
            ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
        )
        sdxl = True
    except Exception:
        try:
            _pipe = StableDiffusionPipeline.from_single_file(
                ckpt_path, torch_dtype=torch.float16, use_safetensors=True
            )
            sdxl = False
        except Exception as e:
            INIT_ERROR = f"Failed to load pipeline: {e}"
            print(f"[ERROR] {INIT_ERROR}")
            return
    if hasattr(_pipe, "enable_attention_slicing"):
        _pipe.enable_attention_slicing("max")
    if hasattr(_pipe, "enable_vae_slicing"):
        _pipe.enable_vae_slicing()
    if hasattr(_pipe, "set_progress_bar_config"):
        _pipe.set_progress_bar_config(disable=True)
    manifest = load_lora_manifest(local_dir)
    print(f"[INFO] LoRAs available: {list(manifest.keys())}")
    # Publish
    pipe = _pipe
    IS_SDXL = sdxl
    LORA_MANIFEST = manifest

def apply_loras(selected: List[str], scale: float, repo_dir: str):
    if not selected or scale <= 0:
        return
    for name in selected:
        meta = LORA_MANIFEST.get(name)
        if not meta:
            print(f"[WARN] Requested LoRA '{name}' not in manifest.")
            continue
        try:
            if "path" in meta:
                pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name)
            else:
                pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
            print(f"[INFO] Loaded LoRA: {name}")
        except Exception as e:
            print(f"[WARN] LoRA load failed for {name}: {e}")
    try:
        pipe.set_adapters(selected, adapter_weights=[float(scale)] * len(selected))
        print(f"[INFO] Activated LoRAs: {selected} at scale {scale}")
    except Exception as e:
        print(f"[WARN] set_adapters failed: {e}")

@spaces.GPU
def txt2img(
    prompt: str,
    negative: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    images: int,
    seed: Optional[int],
    scheduler: str,
    loras: List[str],
    lora_scale: float,
    fuse_lora: bool,
):
    if pipe is None:
        raise RuntimeError(f"Model not initialized. {INIT_ERROR or 'Check Space secrets and logs.'}")
    local_device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(local_device)
    if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
        try:
            pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
        except Exception as e:
            print(f"[WARN] Scheduler switch failed: {e}")
    apply_loras(loras, lora_scale, REPO_DIR)
    if fuse_lora and loras:
        try:
            pipe.fuse_lora(lora_scale=float(lora_scale))
        except Exception as e:
            print(f"[WARN] fuse_lora failed: {e}")
    generator = torch.Generator(device=local_device).manual_seed(int(seed)) if seed not in (None, "") else None
    kwargs: Dict[str, Any] = dict(
        prompt=prompt or "",
        negative_prompt=negative or None,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        num_images_per_prompt=int(images),
        generator=generator,
    )
    with torch.inference_mode():
        out = pipe(**kwargs)
    return out.images

def warmup():
    try:
        _ = txt2img("warmup", "", 512, 512, 4, 4.0, 1, 1234, "default", [], 0.0, False)
    except Exception as e:
        print(f"[WARN] Warmup failed: {e}")
# UI
with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
    status = gr.Markdown("")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3)
        negative = gr.Textbox(label="Negative Prompt", lines=3)
    with gr.Row():
        width = gr.Slider(256, 1536, 1024, step=64, label="Width")
        height = gr.Slider(256, 1536, 1024, step=64, label="Height")
    with gr.Row():
        steps = gr.Slider(5, 80, 30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 20.0, 6.5, step=0.1, label="Guidance")
        images = gr.Slider(1, 4, 1, step=1, label="Images")
    with gr.Row():
        seed = gr.Number(value=None, precision=0, label="Seed (blank=random)")
        scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler")
    lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)")
    lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
    fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")
    btn = gr.Button("Generate", variant="primary", interactive=False)
    gallery = gr.Gallery(columns=4, height=420)

    def _startup():
        bootstrap_model()
        if INIT_ERROR:
            return gr.update(value=f"❌ Init failed: {INIT_ERROR}"), gr.update(choices=[]), gr.update(interactive=False)
        msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
        # Populate LoRA choices (manifest could come from repo, Space file, or built-in fallback)
        return gr.update(value=msg), gr.update(choices=list(LORA_MANIFEST.keys())), gr.update(interactive=True)

    demo.load(_startup, outputs=[status, lora_names, btn])
    if DO_WARMUP:
        demo.load(lambda: warmup(), inputs=None, outputs=None)
    btn.click(
        txt2img,
        inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
        outputs=[gallery],
        api_name="txt2img",
        concurrency_limit=1,
        concurrency_id="gpu_queue",
    )

demo.queue(max_size=32, default_concurrency_limit=1).launch()
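
# A minimal sketch of calling the exposed /txt2img endpoint from another machine
# with gradio_client (the Space id below is a placeholder; the argument order
# follows the `inputs` list wired to btn.click above):
#
#   from gradio_client import Client
#   client = Client("<user>/<space>")
#   images = client.predict(
#       "a prompt", "", 1024, 1024, 30, 6.5, 1, 1234, "dpmpp_2m", [], 0.7, False,
#       api_name="/txt2img",
#   )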