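# Hugging Face page header captured together with the file (not program code):
#   uploader: revi13 · commit message: "Update app.py" · commit: fd6d685 (verified)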
# app.py — InstantID × Beautiful Realistic Asians v7 (ZeroGPU-friendly, persistent cache)
"""Persistent-cache backend for InstantID portrait generation.

* Model assets are stored under /data when it is writable, otherwise under ~/.cache.
* Downloads use a simple wget-based fetch with retries.
"""
# --- Monkey-patch: provide torchvision.transforms.functional_tensor ----------
# According to the original note, torchvision 0.17+ no longer ships this
# module.  The shim is installed ONLY when the real module cannot be imported,
# so an older torchvision keeps its complete implementation instead of having
# it replaced by a one-function stub.
import sys
import types

from torchvision.transforms import functional as F

try:
    import torchvision.transforms.functional_tensor  # noqa: F401
except ImportError:
    mod = types.ModuleType("torchvision.transforms.functional_tensor")
    # Only rgb_to_grayscale is needed, so that is the single alias exposed.
    mod.rgb_to_grayscale = F.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = mod
# ---------------------------------------------------------------------------
import os, subprocess, cv2, torch, spaces, gradio as gr, numpy as np
from pathlib import Path
from PIL import Image
from diffusers import (
StableDiffusionPipeline, ControlNetModel,
DPMSolverMultistepScheduler, AutoencoderKL,
)
from compel import Compel
from insightface.app import FaceAnalysis
##############################################################################
# 0. キャッシュ用ディレクトリ
##############################################################################
# Prefer the /data volume when it exists and is writable; otherwise fall back
# to a per-user cache directory under the home folder.
PERSIST_BASE = Path("/data")
if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK):
    CACHE_ROOT = PERSIST_BASE.joinpath("instantid_cache")
else:
    CACHE_ROOT = Path.home().joinpath(".cache", "instantid_cache")
print("cache →", CACHE_ROOT)

MODELS_DIR = CACHE_ROOT.joinpath("models")
LORA_DIR = MODELS_DIR.joinpath("Lora")  # holds the FaceID LoRA and similar files
EMB_DIR = CACHE_ROOT.joinpath("embeddings")
UPSCALE_DIR = CACHE_ROOT.joinpath("realesrgan")

# Create the whole tree up front so later downloads can write immediately.
for p in (MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR):
    p.mkdir(parents=True, exist_ok=True)
def dl(url: str, dst: Path, attempts: int = 2):
    """Download *url* to *dst* with wget, retrying on failure.

    An existing, non-empty *dst* counts as a cache hit and nothing is fetched.
    Data is written to a sibling ``<name>.part`` file first and moved onto
    *dst* only after wget exits successfully with a non-empty result, so a
    failed or interrupted download never leaves a truncated file that a later
    call would mistake for a valid cached asset.

    Args:
        url: Address handed to wget.
        dst: Final location of the downloaded file.
        attempts: Maximum number of wget invocations.

    Raises:
        RuntimeError: If no attempt produced a file (including the case where
            wget itself cannot be started).
    """
    # A zero-byte file is what a failed ``wget -O`` leaves behind; treat it as
    # missing so a previously poisoned cache entry is repaired.
    if dst.exists() and dst.stat().st_size > 0:
        print("✓", dst.relative_to(CACHE_ROOT))
        return

    tmp = dst.with_name(dst.name + ".part")
    for i in range(1, attempts + 1):
        print(f"⬇ {dst.name} (try {i}/{attempts})")
        try:
            rc = subprocess.call(["wget", "-q", "-O", str(tmp), url])
        except OSError as err:
            # wget is missing or not executable: retrying cannot help.
            tmp.unlink(missing_ok=True)
            raise RuntimeError(f"download failed → {url}") from err
        if rc == 0 and tmp.exists() and tmp.stat().st_size > 0:
            tmp.replace(dst)
            return
        tmp.unlink(missing_ok=True)
    raise RuntimeError(f"download failed → {url}")
##############################################################################
# 1. 必要アセットのダウンロード
##############################################################################
print("— asset check —")


def _dl_first(urls, dst):
    """Fetch *dst* from the first URL in *urls* that downloads successfully.

    Re-raises the last RuntimeError when every URL fails.
    """
    for idx, u in enumerate(urls, 1):
        try:
            dl(u, dst)
            return
        except RuntimeError:
            # dl() only raises when it had to download, so anything now at
            # *dst* is a leftover of the failed attempt.  Remove it, otherwise
            # the next dl() call would treat it as an already-cached asset and
            # the fallback URL would never actually be fetched.
            dst.unlink(missing_ok=True)
            if idx == len(urls):
                raise
            print(" ↳ fallback URL …")


# 1-A. Base checkpoint
BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
dl(
    "https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16",
    BASE_CKPT,
)

# 1-B. FaceID LoRA (delta weights only)
LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
dl(
    "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors",
    LORA_FILE,
)

# 1-C. Textual-inversion embeddings: file name -> candidate URLs, tried in order.
# Hugging Face fallbacks use /resolve/ rather than /raw/: for Git-LFS tracked
# binaries /raw/ returns the small LFS pointer text, not the weights.
EMB_URLS = {
    "ng_deepnegative_v1_75t.pt": [
        "https://huggingface.co/datasets/gsdf/EasyNegative/resolve/main/ng_deepnegative_v1_75t.pt",
        "https://huggingface.co/mrpxl2/animetarotV51.safetensors/resolve/cc3008c0148061896549a995cc297aef0af4ef1b/ng_deepnegative_v1_75t.pt",
    ],
    "badhandv4.pt": [
        "https://huggingface.co/datasets/gsdf/ConceptLab/resolve/main/badhandv4.pt",
        "https://huggingface.co/nolanaatama/embeddings/resolve/main/badhandv4.pt",
    ],
    "CyberRealistic_Negative-neg.pt": [
        "https://huggingface.co/datasets/gsdf/ConceptLab/resolve/main/CyberRealistic_Negative-neg.pt",
        # NOTE(review): this fallback targets a ".civitai.info" file, which looks
        # like metadata rather than an embedding — confirm and replace the URL.
        "https://huggingface.co/wsj1995/embeddings/raw/main/CyberRealistic_Negative-neg.civitai.info",
    ],
    "UnrealisticDream.pt": [
        "https://huggingface.co/datasets/gsdf/ConceptLab/resolve/main/UnrealisticDream.pt",
        "https://huggingface.co/imagepipeline/UnrealisticDream/resolve/main/f84133b4-aad8-44be-b9ce-7e7e3a8c111f.pt",
    ],
}
for fname, urls in EMB_URLS.items():
    _dl_first(urls, EMB_DIR / fname)

# 1-D. Real-ESRGAN weights (x8)
# NOTE(review): the fallback URLs name different weight files than the primary
# one but are all saved as RealESRGAN_x8plus.pth — verify they are compatible.
RRG_WEIGHTS = UPSCALE_DIR / "RealESRGAN_x8plus.pth"
RRG_URLS = [
    "https://huggingface.co/NoCrypt/Superscale_RealESRGAN/resolve/main/RealESRGAN_x8plus.pth",
    "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x8.pth",
    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/8x_NMKD-Superscale_100k.pth",
]
_dl_first(RRG_URLS, RRG_WEIGHTS)
##############################################################################
# 2. ランタイム初期化
##############################################################################
# Select the compute device once; everything below derives from this flag.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
dtype = torch.float16 if _has_cuda else torch.float32
print("device:", device, "| dtype:", dtype)

# ONNX Runtime providers for insightface: CUDA first when present, CPU always.
providers = ["CPUExecutionProvider"]
if _has_cuda:
    providers.insert(0, "CUDAExecutionProvider")

# Face analysis model pack, stored under the cache root.
# NOTE(review): face_app is not referenced elsewhere in the visible code.
face_app = FaceAnalysis(name="buffalo_l", root=str(CACHE_ROOT), providers=providers)
_ctx_id = 0 if _has_cuda else -1
face_app.prepare(ctx_id=_ctx_id, det_size=(640, 640))
# ControlNet + SD pipeline.  The InstantID ControlNet is currently disabled:
# controlnet = ControlNetModel.from_pretrained(
#     "InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype
# )
pipe = StableDiffusionPipeline.from_single_file(
    BASE_CKPT,
    torch_dtype=dtype,
    safety_checker=None,
    use_safetensors=True,
    clip_skip=2,
)

# Replace the checkpoint's VAE with sd-vae-ft-mse.
_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
pipe.vae = _vae.to(device)
# pipe.controlnet = controlnet

# DPM-Solver++ (SDE variant) with Karras sigmas, reusing the existing config.
_scheduler_opts = {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, **_scheduler_opts
)

# Key step: load the IP-Adapter (image encoder included) straight from the
# official Hub repository; the weight file lives in its "models" subfolder.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter-plus-face_sd15.bin",
)

# FaceID LoRA (delta weights only), then the default IP-Adapter strength.
pipe.load_lora_weights(str(LORA_DIR), weight_name=LORA_FILE.name)
pipe.set_ip_adapter_scale(0.65)
# Load every textual-inversion embedding found in the cache; the file stem is
# used as the prompt token.
for emb in EMB_DIR.glob("*.*"):
    try:
        pipe.load_textual_inversion(emb, token=emb.stem)
    except Exception as err:
        # Best effort: one unreadable embedding must not stop start-up, but
        # report why it was skipped so a bad download can be diagnosed.
        print("emb skip →", emb.name, f"({type(err).__name__}: {err})")
    else:
        print("emb loaded →", emb.stem)

pipe.to(device)

# Compel turns prompt strings into embeddings; long prompts are not truncated.
compel_proc = Compel(
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    truncate_long_prompts=False,
)
print("pipeline ready ✔")
##############################################################################
# 3. アップスケーラ
##############################################################################
# Optional Real-ESRGAN upscaler.  Any failure here (missing package, weights
# that do not match the network, ...) disables it; generate() then falls back
# to a plain LANCZOS resize.
try:
    from basicsr.archs.rrdb_arch import RRDBNet
    from realesrgan import RealESRGANer

    # Keep the module-level alias the previous version of this block exposed.
    RealESRGAN = RealESRGANer

    # Keyword arguments on purpose: RRDBNet's third positional parameter is
    # `scale`, so the former call RRDBNet(3, 3, 64, 23, 32, scale=8) supplied
    # `scale` twice and raised TypeError.
    # NOTE(review): confirm that basicsr's RRDBNet really implements an 8x
    # network and that it matches the downloaded weights.
    rrdb = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=8,
    )
    # RealESRGANer loads the weights itself from `model_path` and provides the
    # enhance(img, outscale=...) method that generate() calls.
    # pre_pad=0: presumably avoids cropping by pre_pad * scale, which would be
    # wrong if the network's real factor differs from 8 — TODO confirm.
    upsampler = RealESRGANer(
        scale=8,
        model_path=str(RRG_WEIGHTS),
        model=rrdb,
        pre_pad=0,
        device=device,
    )
    UPSCALE_OK = True
except Exception as e:
    print("Real-ESRGAN disabled →", e)
    UPSCALE_OK = False
##############################################################################
# 4. プロンプト & 生成関数
##############################################################################
# Positive prompt template; "{subject}" is filled in per request.
BASE_PROMPT = ", ".join(
    (
        "Cinematic photo",
        "(best quality:1.1)",
        "ultra-realistic",
        "photorealistic of {subject}",
        "natural skin texture",
        "bokeh",
        "standing",
        "front view",
        "full body shot",
        "thighs",
        "Canon EOS R5",
        "85 mm",
        "f/1.4",
        "ISO 200",
        "1/160 s",
        "RAW",
    )
)

# Negative prompt shared by every request: embedding tokens first, then
# generic quality / style / anatomy terms.
NEG_PROMPT = ", ".join(
    (
        "ng_deepnegative_v1_75t",
        "BadDream:0.6",
        "UnrealisticDream:0.8",
        "badhandv4:0.9",
        "(worst quality:2)",
        "(low quality:1.8)",
        "lowres",
        "blurry",
        "jpeg artifacts",
        "painting",
        "sketch",
        "illustration",
        "cartoon",
        "anime",
        "cgi",
        "render",
        "3d",
        "monochrome",
        "grayscale",
        "text",
        "logo",
        "watermark",
        "signature",
        "username",
        "bad anatomy",
        "malformed",
        "deformed",
        "extra limbs",
        "fused fingers",
        "missing fingers",
        "missing arms",
        "missing legs",
        "skin blemishes",
        "acne",
        "age spot",
    )
)
def _compose_prompts(subject, add_prompt, add_neg):
    """Return the (positive, negative) prompt strings for one request.

    A missing or blank *subject* falls back to the default subject; *subject*
    may be None (e.g. JSON ``null`` coming from the HTTP API).
    """
    subject_text = (subject or "").strip() or "a beautiful 20yo woman"
    prompt = BASE_PROMPT.format(subject=subject_text)
    if add_prompt:
        prompt += ", " + add_prompt
    neg = NEG_PROMPT + (", " + add_neg if add_neg else "")
    return prompt, neg


@spaces.GPU(duration=60)
def generate(
    face_np, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one portrait conditioned on an uploaded face image.

    Args:
        face_np: Face image as a NumPy array (passed to ``Image.fromarray``).
        subject: Subject text inserted into BASE_PROMPT; blank/None uses the default.
        add_prompt: Extra text appended to the positive prompt (optional).
        add_neg: Extra text appended to the negative prompt (optional).
        cfg: Guidance scale handed to the pipeline.
        ip_scale: IP-Adapter scale applied before sampling.
        steps: Requested number of inference steps.
        w: Output width in pixels.
        h: Output height in pixels.
        upscale: Whether to enlarge the result afterwards.
        up_factor: Enlargement factor used when *upscale* is true.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        The resulting PIL image.

    Raises:
        gr.Error: If no face image was supplied.
    """
    print("🚀 リクエスト受信!")
    print("face_np shape:", getattr(face_np, "shape", "None"))
    print("subject:", subject)
    print("add_prompt:", add_prompt)
    print("add_neg:", add_neg)
    print("cfg:", cfg)
    print("ip_scale:", ip_scale)
    print("steps:", steps)
    print("width x height:", w, "x", h)
    print("upscale:", upscale, "×", up_factor)

    if face_np is None or face_np.size == 0:
        raise gr.Error("顔画像をアップロードしてください。")

    prompt, neg = _compose_prompts(subject, add_prompt, add_neg)
    pipe.set_ip_adapter_scale(ip_scale)
    img_in = Image.fromarray(face_np)

    # Encode both prompts in one Compel call, then add a batch dimension to
    # each resulting embedding with unsqueeze(0).
    prompt_embeds, negative_prompt_embeds = compel_proc([prompt, neg])
    prompt_embeds = prompt_embeds.unsqueeze(0)
    negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0)

    result = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        ip_adapter_image=img_in,
        # Runs five more steps than requested, as in the original code.
        # NOTE(review): confirm this offset is intentional.
        num_inference_steps=int(steps) + 5,
        guidance_scale=cfg,
        width=int(w),
        height=int(h),
    ).images[0]

    if not upscale:
        return result

    if UPSCALE_OK:
        # The upsampler works on BGR arrays, so convert on the way in and out.
        up, _ = upsampler.enhance(
            cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR), outscale=up_factor
        )
        return Image.fromarray(cv2.cvtColor(up, cv2.COLOR_BGR2RGB))

    # Real-ESRGAN unavailable: plain LANCZOS resize instead.
    return result.resize(
        (int(result.width * up_factor), int(result.height * up_factor)),
        Image.LANCZOS,
    )
##############################################################################
# 5. Gradio UI
##############################################################################
# FastAPI app
# These names are used below but were never imported anywhere in the file,
# which made the module fail with NameError on import.
import base64
from io import BytesIO

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


@app.post("/generate")
async def generate_image_api(request: Request):
    """HTTP endpoint wrapping generate().

    Expects a JSON body with a base64 (optionally data-URL) ``image`` plus the
    optional keys ``subject``, ``add_prompt``, ``add_neg``, ``cfg``,
    ``ip_scale``, ``steps``, ``width``, ``height``, ``upscale``, ``up_factor``.

    Returns:
        ``{"data": ["data:image/png;base64,..."]}`` on success, or a JSON
        ``{"error": ...}`` response with status 500 on any failure.
    """
    try:
        body = await request.json()
        image_base64 = body.get("image")
        subject = body.get("subject", "beautiful woman")
        add_prompt = body.get("add_prompt", "")
        add_neg = body.get("add_neg", "")
        cfg = float(body.get("cfg", 6.0))
        ip_scale = float(body.get("ip_scale", 0.65))
        steps = int(body.get("steps", 20))
        width = int(body.get("width", 512))
        height = int(body.get("height", 768))
        upscale = bool(body.get("upscale", False))
        up_factor = int(body.get("up_factor", 2))

        # Decode the image; a "data:...;base64," prefix is stripped if present.
        if not image_base64:
            raise ValueError("Base64画像がありません")
        img_bytes = base64.b64decode(image_base64.split(",")[-1])
        pil_image = Image.open(BytesIO(img_bytes)).convert("RGB")
        np_image = np.array(pil_image)

        # NOTE(review): generate() is a synchronous call inside an async
        # handler, so it runs on the event loop until it returns.
        result_image = generate(
            face_np=np_image,
            subject=subject,
            add_prompt=add_prompt,
            add_neg=add_neg,
            cfg=cfg,
            ip_scale=ip_scale,
            steps=steps,
            w=width,
            h=height,
            upscale=upscale,
            up_factor=up_factor,
        )

        # Return the result as a PNG data URL.
        buffer = BytesIO()
        result_image.save(buffer, format="PNG")
        result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {"data": ["data:image/png;base64," + result_base64]}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
# Gradio UI.  The previous `from app import demo` imported this very file
# (circular import) and no `demo` object is defined in it, so the UI is built
# here.  Default values mirror the defaults of the /generate endpoint.
with gr.Blocks(title="InstantID × Beautiful Realistic Asians v7") as demo:
    gr.Markdown("## InstantID × Beautiful Realistic Asians v7")
    with gr.Row():
        with gr.Column():
            face_in = gr.Image(label="Face image", type="numpy")
            subject_in = gr.Textbox(label="Subject", placeholder="a beautiful 20yo woman")
            add_prompt_in = gr.Textbox(label="Additional prompt")
            add_neg_in = gr.Textbox(label="Additional negative prompt")
            cfg_in = gr.Slider(minimum=1, maximum=15, value=6.0, step=0.5, label="CFG scale")
            ip_scale_in = gr.Slider(minimum=0.0, maximum=1.5, value=0.65, step=0.05, label="IP-Adapter scale")
            steps_in = gr.Slider(minimum=10, maximum=50, value=20, step=1, label="Steps")
            width_in = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="Width")
            height_in = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Height")
            upscale_in = gr.Checkbox(value=False, label="Upscale")
            up_factor_in = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="Upscale factor")
            run_btn = gr.Button("Generate")
        with gr.Column():
            result_out = gr.Image(label="Result")

    # Input order matches generate()'s positional parameters.
    run_btn.click(
        fn=generate,
        inputs=[
            face_in, subject_in, add_prompt_in, add_neg_in, cfg_in, ip_scale_in,
            steps_in, width_in, height_in, upscale_in, up_factor_in,
        ],
        outputs=result_out,
    )

# Mount the Gradio UI on the FastAPI app at the root path.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)