import os, json
from typing import List, Dict, Any, Optional
from PIL import Image
import torch
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from diffusers import (
    StableDiffusionPipeline,       # SD 1.x/2.x single-file loader
    StableDiffusionXLPipeline,     # SDXL single-file loader
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)

# -------- Config --------
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/mixy").strip()
CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "realismIllustriousBy_v50FP16.safetensors").strip()
HF_TOKEN = os.getenv("HF_TOKEN", None)
DO_WARMUP = os.getenv("WARMUP", "1") == "1"
LORAS_JSON = os.getenv("LORAS_JSON", "").strip()
REPO_DIR = "/home/user/model"
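
# Expected LoRA manifest shape: the same schema apply_loras() consumes, whether it
# comes from the LORAS_JSON env var, a loras.json in the model repo, or a local
# loras.json. A minimal sketch; "my_local_lora" and its path are illustrative only:
#
# {
#   "MoriiMee_Gothic": {
#     "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1",
#     "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors"
#   },
#   "my_local_lora": {"path": "loras/my_local_lora.safetensors"}
# }
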
SCHEDULERS = {
"default": None,
"euler_a": EulerAncestralDiscreteScheduler,
"euler": EulerDiscreteScheduler,
"ddim": DDIMScheduler,
"lms": LMSDiscreteScheduler,
"pndm": PNDMScheduler,
"dpmpp_2m": DPMSolverMultistepScheduler,
}

# -------- Globals --------
pipe = None
IS_SDXL = False
LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
INIT_ERROR: Optional[str] = None

# -------- Helpers --------
def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
    # Priority: LORAS_JSON env var > loras.json in the model repo > loras.json next to app.py.
    if LORAS_JSON:
        try:
            parsed = json.loads(LORAS_JSON)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse LORAS_JSON: {e}")
    repo_manifest = os.path.join(repo_dir, "loras.json")
    if os.path.exists(repo_manifest):
        try:
            with open(repo_manifest, "r", encoding="utf-8") as f:
                parsed = json.load(f)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse repo loras.json: {e}")
    local_manifest = os.path.join(os.getcwd(), "loras.json")
    if os.path.exists(local_manifest):
        try:
            with open(local_manifest, "r", encoding="utf-8") as f:
                parsed = json.load(f)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse local loras.json: {e}")
    print("[INFO] Using built-in LoRA fallback manifest.")
    return {
        "MoriiMee_Gothic": {
            "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1",
            "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors"
        }
    }

# -------- Bootstrap (CPU) --------
def bootstrap_model():
    """
    Try SD (1.x/2.x) single-file first, then SDXL single-file, to maximize compatibility
    with older diffusers that don’t expose DiffusionPipeline.from_single_file.
    """
    global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
    INIT_ERROR = None

    if not MODEL_REPO_ID or not CHECKPOINT_FILENAME:
        INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return

    try:
        local_dir = snapshot_download(
            repo_id=MODEL_REPO_ID,
            token=HF_TOKEN,
            local_dir=REPO_DIR,
            ignore_patterns=["*.md"],
        )
    except Exception as e:
        INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}"
        print(f"[ERROR] {INIT_ERROR}")
        return

    ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME)
    if not os.path.exists(ckpt_path):
        INIT_ERROR = f"Checkpoint not found at {ckpt_path}. Check CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return

    _pipe = None
    _is_sdxl = False

    # 1) Try SD 1.x/2.x first (most single-file merges are SD), then fall back to SDXL.
    try:
        _pipe = StableDiffusionPipeline.from_single_file(
            ckpt_path, torch_dtype=torch.float16, use_safetensors=True
        )
        _is_sdxl = False
    except Exception as e_sd:
        print(f"[INFO] SD load failed or not SD: {e_sd}")
        try:
            _pipe = StableDiffusionXLPipeline.from_single_file(
                ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
            )
            _is_sdxl = True
        except Exception as e_sdxl:
            INIT_ERROR = f"Failed to load pipeline (SD and SDXL): SD={e_sd} | SDXL={e_sdxl}"
            print(f"[ERROR] {INIT_ERROR}")
            return

    # Memory-friendly defaults; guarded with hasattr so older diffusers versions don't break.
    if hasattr(_pipe, "enable_attention_slicing"):
        _pipe.enable_attention_slicing("max")
    if hasattr(_pipe, "enable_vae_slicing"):
        _pipe.enable_vae_slicing()
    if hasattr(_pipe, "set_progress_bar_config"):
        _pipe.set_progress_bar_config(disable=True)

    manifest = load_lora_manifest(local_dir)
    print(f"[INFO] LoRAs available: {list(manifest.keys())}")

    pipe = _pipe
    IS_SDXL = _is_sdxl
    LORA_MANIFEST = manifest

def apply_loras(selected: List[str], scale: float, repo_dir: str):
    if not selected or scale <= 0:
        return
    for name in selected:
        meta = LORA_MANIFEST.get(name)
        if not meta:
            print(f"[WARN] Requested LoRA '{name}' not in manifest.")
            continue
        try:
            if "path" in meta:
                # Weight file shipped inside the downloaded model repo.
                pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name)
            else:
                # Weight file pulled from a separate Hub repo.
                pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
            print(f"[INFO] Loaded LoRA: {name}")
        except Exception as e:
            print(f"[WARN] LoRA load failed for {name}: {e}")
    try:
        pipe.set_adapters(selected, adapter_weights=[float(scale)] * len(selected))
        print(f"[INFO] Activated LoRAs: {selected} at scale {scale}")
    except Exception as e:
        print(f"[WARN] set_adapters failed: {e}")

# -------- Generation (ZeroGPU) --------
@spaces.GPU
def txt2img(
    prompt: str,
    negative: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    images: int,
    seed: Optional[int],
    scheduler: str,
    loras: List[str],
    lora_scale: float,
    fuse_lora: bool,
):
    if pipe is None:
        raise RuntimeError(f"Model not initialized. {INIT_ERROR or 'Check Space secrets and logs.'}")

    local_device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(local_device)

    if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
        try:
            pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
        except Exception as e:
            print(f"[WARN] Scheduler switch failed: {e}")

    apply_loras(loras, lora_scale, REPO_DIR)
    if fuse_lora and loras:
        try:
            pipe.fuse_lora(lora_scale=float(lora_scale))
        except Exception as e:
            print(f"[WARN] fuse_lora failed: {e}")

    generator = (
        torch.Generator(device=local_device).manual_seed(int(seed))
        if seed not in (None, "")
        else None
    )
    kwargs: Dict[str, Any] = dict(
        prompt=prompt or "",
        negative_prompt=negative or None,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        num_images_per_prompt=int(images),
        generator=generator,
    )
    with torch.inference_mode():
        out = pipe(**kwargs)
    return out.images

# -------- UI --------
with gr.Blocks(title="SDXL/SD single-file (ZeroGPU, LoRA-ready)") as demo:
    status = gr.Markdown("")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3)
        negative = gr.Textbox(label="Negative Prompt", lines=3)
    with gr.Row():
        width = gr.Slider(256, 1536, 1024, step=64, label="Width")
        height = gr.Slider(256, 1536, 1024, step=64, label="Height")
    with gr.Row():
        steps = gr.Slider(5, 80, 30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 20.0, 6.5, step=0.1, label="Guidance")
        images = gr.Slider(1, 4, 1, step=1, label="Images")
    with gr.Row():
        seed = gr.Number(value=None, precision=0, label="Seed (blank=random)")
        scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler")

    lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)")
    lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
    fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")

    btn = gr.Button("Generate", variant="primary", interactive=False)
    gallery = gr.Gallery(columns=4, height=420)

    def _startup():
        bootstrap_model()
        if INIT_ERROR:
            return (
                gr.update(value=f"❌ Init failed: {INIT_ERROR}"),
                gr.update(choices=[]),
                gr.update(value=1024, minimum=256, maximum=1536, step=64),
                gr.update(value=1024, minimum=256, maximum=1536, step=64),
                gr.update(interactive=False),
            )
        default_wh = 1024 if IS_SDXL else 512
        msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
        # Warm up only after model is ready (avoids race)
        if DO_WARMUP:
            try:
                _ = txt2img("warmup", "", default_wh, default_wh, 4, 4.0, 1, 1234, "default", [], 0.0, False)
            except Exception as e:
                print(f"[WARN] Warmup failed: {e}")
        return (
            gr.update(value=msg),
            gr.update(choices=list(LORA_MANIFEST.keys())),
            gr.update(value=default_wh, minimum=256, maximum=1536, step=64),
            gr.update(value=default_wh, minimum=256, maximum=1536, step=64),
            gr.update(interactive=True),
        )

    demo.load(_startup, outputs=[status, lora_names, width, height, btn])

    btn.click(
        txt2img,
        inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
        outputs=[gallery],
        api_name="txt2img",
        concurrency_limit=1,
        concurrency_id="gpu_queue",
    )

demo.queue(max_size=32, default_concurrency_limit=1).launch()
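
# A minimal client-side sketch (assumes the Space is public; "<user>/<space>" is a
# placeholder for the real Space id). Positional arguments mirror the `inputs`
# list wired to btn.click above:
#
#   from gradio_client import Client
#   client = Client("<user>/<space>")
#   images = client.predict(
#       "a portrait photo", "",            # prompt, negative
#       1024, 1024, 30, 6.5, 1,            # width, height, steps, guidance, images
#       1234, "dpmpp_2m", [], 0.7, False,  # seed, scheduler, loras, lora_scale, fuse_lora
#       api_name="/txt2img",
#   )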