# pixagram-dev / models.py
# (Hub page header: primerz, commit "Update models.py", e036d10, verified)
"""
Model loading and initialization for Pixagram AI Pixel Art Generator
FIXED VERSION - Uses correct InstantID pipeline and Compel encoder
"""
import torch
import time
import os
from diffusers import (
ControlNetModel,
AutoencoderKL,
LCMScheduler
)
from transformers import (
CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
)
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector, OpenposeDetector, LeresDetector, MidasDetector, MediapipeFaceDetector
from huggingface_hub import hf_hub_download, snapshot_download
# --- START FIX: Import correct pipeline and Compel ---
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
from compel import Compel, ReturnedEmbeddingsType
# --- END FIX ---
from config import (
device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
)
# The following helpers (download_model_with_retry, load_face_analysis,
# load_depth_detector, load_openpose_detector and load_mediapipe_face_detector)
# are unchanged from the previous revision of this module.
def download_model_with_retry(repo_id, filename, max_retries=None, **kwargs):
    """
    Download a single file from the HuggingFace Hub, retrying on failure.

    Args:
        repo_id: HuggingFace repository id to download from.
        filename: Path of the file inside the repository.
        max_retries: Total number of attempts. Defaults to
            ``DOWNLOAD_CONFIG['max_retries']``. Values below 1 are treated as
            1, so at least one attempt is always made.
        **kwargs: Forwarded to ``hf_hub_download`` (e.g. ``repo_type``).
            ``token`` defaults to ``HUGGINGFACE_TOKEN`` when that is set and
            the caller did not pass one.

    Returns:
        The local file path returned by ``hf_hub_download``.

    Raises:
        Re-raises the exception from the final failed attempt.
    """
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']
    # With 0 or a negative value the loop would never run and the function
    # would silently return None, only failing later in the caller.
    max_retries = max(1, max_retries)
    # Only inject the configured token when the caller did not supply one.
    if HUGGINGFACE_TOKEN and "token" not in kwargs:
        kwargs["token"] = HUGGINGFACE_TOKEN
    for attempt in range(1, max_retries + 1):
        try:
            print(f" Attempting to download {filename} (attempt {attempt}/{max_retries})...")
            return hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                **kwargs
            )
        except Exception as e:
            print(f" [WARNING] Download attempt {attempt} failed: {e}")
            if attempt == max_retries:
                print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
                raise
            print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
            time.sleep(DOWNLOAD_CONFIG['retry_delay'])
def load_face_analysis():
    """
    Load the InsightFace ``antelopev2`` face-analysis model on CPU.

    The model files are fetched from the ``DIAMONIK7777/antelopev2`` Hub repo
    into ``<root>/models/antelopev2``. ``root`` is the same directory that is
    passed to ``FaceAnalysis``, so the download location and the load
    location are derived from one value.

    Returns:
        tuple: ``(face_app, True)`` on success, or ``(None, False)`` on any
        error (the traceback is printed).
    """
    print("Loading face analysis model...")
    model_root = "/data"
    try:
        # The snapshot path return value was never used; only the files on
        # disk matter, at the location FaceAnalysis(root=model_root) reads.
        snapshot_download(
            repo_id="DIAMONIK7777/antelopev2",
            local_dir=os.path.join(model_root, "models", "antelopev2"),
        )
        # Run InsightFace on CPU to save VRAM for the diffusion models.
        face_app = FaceAnalysis(name='antelopev2', root=model_root, providers=['CPUExecutionProvider'])
        face_app.prepare(ctx_id=0, det_size=(640, 640))
        print(" [OK] Face analysis loaded (on CPU)")
        return face_app, True
    except Exception as e:
        print(f" [ERROR] Face detection not available: {e}")
        import traceback
        traceback.print_exc()
        return None, False
def load_depth_detector():
    """
    Load a depth annotator, falling back through Leres -> Zoe -> Midas.

    Every candidate is loaded from the ``lllyasviel/Annotators`` repo and is
    deliberately left on CPU (it is never moved to ``device``).

    Returns:
        tuple: ``(detector, detector_type, success)`` where ``detector_type``
        is one of ``'leres'``, ``'zoe'`` or ``'midas'``; ``(None, None, False)``
        when no candidate could be loaded.
    """
    print("Loading depth detector with fallback hierarchy...")
    # (class, display name, returned type tag, attempt note, log level on failure)
    fallback_chain = (
        (LeresDetector, "LeresDetector", 'leres', "highest quality", "INFO"),
        (ZoeDetector, "ZoeDetector", 'zoe', "fallback #1", "INFO"),
        (MidasDetector, "MidasDetector", 'midas', "fallback #2", "WARNING"),
    )
    for detector_cls, label, kind, note, level in fallback_chain:
        try:
            print(f" Attempting {label} ({note})...")
            detector = detector_cls.from_pretrained("lllyasviel/Annotators")
            print(f" [OK] {label} loaded successfully (on CPU)")
            return detector, kind, True
        except Exception as e:
            print(f" [{level}] {label} not available: {e}")
    print(" [ERROR] No depth detector available")
    return None, None, False
# --- NEW FUNCTION ---
def load_openpose_detector():
    """
    Load the OpenPose annotator from ``lllyasviel/Annotators``.

    The detector is deliberately left on CPU (no ``.to(device)``).

    Returns:
        tuple: ``(detector, True)`` on success, ``(None, False)`` on failure.
    """
    print("Loading OpenPose detector...")
    try:
        detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
    except Exception as err:
        print(f" [WARNING] OpenPose not available: {err}")
        return None, False
    print(" [OK] OpenPose loaded successfully (on CPU)")
    return detector, True
# --- END NEW FUNCTION ---
# --- NEW FUNCTION ---
def load_mediapipe_face_detector():
    """
    Instantiate the controlnet_aux ``MediapipeFaceDetector``.

    Returns:
        tuple: ``(detector, True)`` on success, or ``(None, False)`` if
        construction raises.
    """
    print("Loading MediapipeFaceDetector...")
    try:
        detector = MediapipeFaceDetector()
    except Exception as err:
        print(f" [WARNING] MediapipeFaceDetector not available: {err}")
        return None, False
    print(" [OK] MediapipeFaceDetector loaded successfully")
    return detector, True
# --- END NEW FUNCTION ---
def _load_controlnet(repo_id, **kwargs):
    """Load one ControlNet in the configured ``dtype`` and move it to ``device``."""
    return ControlNetModel.from_pretrained(repo_id, torch_dtype=dtype, **kwargs).to(device)


def load_controlnets():
    """
    Load the ControlNet models used by the pipeline.

    The depth ControlNet is required: if it fails to load, the exception
    propagates to the caller. The OpenPose and InstantID ControlNets are
    optional and are returned as ``None`` when they cannot be loaded.

    Returns:
        tuple: ``(controlnet_depth, controlnet_instantid, controlnet_openpose,
        instantid_ok)``. InstantID comes before OpenPose in the tuple, and
        the flag only reflects whether the InstantID ControlNet loaded.
    """
    print("Loading ControlNet Zoe Depth model...")
    controlnet_depth = _load_controlnet("xinsir/controlnet-depth-sdxl-1.0")
    # Report the real target device rather than assuming a GPU is present.
    print(f" [OK] ControlNet Depth loaded (on {device})")

    print("Loading ControlNet OpenPose model...")
    try:
        controlnet_openpose = _load_controlnet("xinsir/controlnet-openpose-sdxl-1.0")
        print(f" [OK] ControlNet OpenPose loaded (on {device})")
    except Exception as e:
        print(f" [WARNING] ControlNet OpenPose not available: {e}")
        controlnet_openpose = None

    print("Loading InstantID ControlNet...")
    try:
        controlnet_instantid = _load_controlnet("InstantX/InstantID", subfolder="ControlNetModel")
        print(f" [OK] InstantID ControlNet loaded successfully (on {device})")
    except Exception as e:
        print(f" [WARNING] InstantID ControlNet not available: {e}")
        # Still hand back the models that did load; flag the InstantID failure.
        return controlnet_depth, None, controlnet_openpose, False
    return controlnet_depth, controlnet_instantid, controlnet_openpose, True
# --- START: REMOVED load_image_encoder ---
# (The new pipeline handles this internally)
# --- END: REMOVED load_image_encoder ---
def load_sdxl_pipeline(controlnets):
    """
    Build the SDXL InstantID img2img pipeline.

    Tokenizers and text encoders always come from the SDXL base repo; they
    are loaded explicitly so the same objects can be used by Compel. The
    rest of the model comes from the single-file checkpoint
    ``MODEL_FILES['checkpoint']`` in ``MODEL_REPO``. If downloading or
    loading that checkpoint fails, the SDXL base model is used instead.

    Args:
        controlnets: ControlNet model(s), passed straight through as the
            pipeline's ``controlnet`` component.

    Returns:
        tuple: ``(pipe, True)`` when the custom checkpoint was loaded, or
        ``(pipe, False)`` when the SDXL base fallback was used.
    """
    print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
    print(" Loading base tokenizers and text encoders...")
    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
    tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer_2")
    text_encoder = CLIPTextModel.from_pretrained(
        BASE_MODEL, subfolder="text_encoder", torch_dtype=dtype
    ).to(device)
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
        BASE_MODEL, subfolder="text_encoder_2", torch_dtype=dtype
    ).to(device)
    print(" [OK] Base text/token models loaded")

    # Components shared by the custom-checkpoint path and the fallback path,
    # so the two cannot drift apart.
    pipeline_kwargs = dict(
        controlnet=controlnets,
        torch_dtype=dtype,
        use_safetensors=True,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
    )
    try:
        model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")
        pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
            model_path, **pipeline_kwargs
        ).to(device)
        print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
        return pipe, True
    except Exception as e:
        print(f" [WARNING] Could not load custom checkpoint: {e}")
        print(" Using default SDXL base model")
    # Fall back to the same base repo the text encoders were loaded from.
    pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
        BASE_MODEL, **pipeline_kwargs
    ).to(device)
    return pipe, False
def load_loras(pipe):
    """
    Download the configured LoRA files and register them as adapters on ``pipe``.

    Adapters ``retroart``, ``vga`` and ``lucasart`` are looked up in
    ``MODEL_FILES`` under ``lora_retroart``, ``lora_vga`` and
    ``lora_lucasart``. A missing entry is skipped. A download or load error
    is logged and the adapter is marked as not loaded.

    Returns:
        tuple: ``(status, any_loaded)`` where ``status`` maps each adapter
        name to a bool.
    """
    print("Loading all LORAs from HuggingFace Hub...")
    adapters = (
        ("retroart", "lora_retroart"),
        ("vga", "lora_vga"),
        ("lucasart", "lora_lucasart"),
    )
    status = {}
    for adapter_name, config_key in adapters:
        filename = MODEL_FILES.get(config_key)
        if not filename:
            print(f" [INFO] No file specified for LORA '{adapter_name}', skipping.")
            status[adapter_name] = False
            continue
        try:
            lora_path = download_model_with_retry(MODEL_REPO, filename, repo_type="model")
            pipe.load_lora_weights(lora_path, adapter_name=adapter_name)
        except Exception as e:
            print(f" [WARNING] Could not load LORA {filename}: {e}")
            status[adapter_name] = False
        else:
            print(f" [OK] LORA loaded successfully: {filename} as '{adapter_name}'")
            status[adapter_name] = True
    any_loaded = any(status.values())
    if not any_loaded:
        print(" [WARNING] No LORAs were loaded successfully.")
    return status, any_loaded
# --- START FIX: Replace setup_ip_adapter ---
def setup_ip_adapter(pipe):
    """
    Load the InstantID IP-Adapter weights into ``pipe``.

    Downloads ``ip-adapter.bin`` from ``InstantX/InstantID`` and passes the
    local path to the pipeline's ``load_ip_adapter_instantid`` method.

    Returns:
        tuple: ``(None, True)`` on success, or ``(None, False)`` on failure
        (the traceback is printed). The first slot is always ``None``
        because no separate model object is returned.
    """
    print("Setting up IP-Adapter for InstantID face embeddings...")
    try:
        weights_path = download_model_with_retry(
            "InstantX/InstantID",
            "ip-adapter.bin",
            repo_type="model"
        )
        pipe.load_ip_adapter_instantid(weights_path)
    except Exception as e:
        print(f" [ERROR] Could not setup IP-Adapter: {e}")
        import traceback
        traceback.print_exc()
        return None, False
    print(" [OK] IP-Adapter fully loaded via pipeline")
    return None, True
# --- END FIX ---
# --- START FIX: Replace setup_cappella with setup_compel ---
def setup_compel(pipe):
    """
    Create a Compel prompt encoder over the pipeline's two SDXL text stacks.

    Embeddings are taken from the penultimate, non-normalized hidden states.
    Pooled output is requested only from the second encoder.

    Returns:
        tuple: ``(compel, True)`` on success, or ``(None, False)`` if Compel
        could not be constructed.
    """
    print("Setting up Compel (prompt encoder)...")
    try:
        tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
        encoders = [pipe.text_encoder, pipe.text_encoder_2]
        compel = Compel(
            tokenizer=tokenizers,
            text_encoder=encoders,
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )
    except Exception as err:
        print(f" [WARNING] Compel not available: {err}")
        return None, False
    print(" [OK] Compel loaded successfully.")
    return compel, True
# --- END FIX ---
def setup_scheduler(pipe):
    """Swap ``pipe.scheduler`` for an LCM scheduler built from the current scheduler's config."""
    print("Setting up LCM scheduler...")
    current_config = pipe.scheduler.config
    pipe.scheduler = LCMScheduler.from_config(current_config)
    print(" [OK] LCM scheduler configured")
def optimize_pipeline(pipe):
    """
    Apply optional runtime optimizations to ``pipe``.

    When the configured ``device`` is a CUDA device, this enables xformers
    memory-efficient attention. A failure to enable it is logged and
    otherwise ignored.
    """
    # `device` may be a string such as "cuda" or "cuda:0", or a torch.device.
    # A plain `device == "cuda"` misses "cuda:0" and does not match a
    # torch.device object, so normalize to the device type first.
    if torch.device(device).type == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print(" [OK] xformers enabled")
        except Exception as e:
            print(f" [INFO] xformers not available: {e}")
def load_caption_model():
    """
    Load an image-captioning model, preferring GIT-Large and falling back to BLIP base.

    Neither model is moved to ``device`` here.

    NOTE(review): both models are loaded with the shared ``dtype``. If that
    is float16 while the model stays on CPU, generation may be slow or
    unsupported. Confirm where (if anywhere) the caller moves the model.

    Returns:
        tuple: ``(processor, model, True, kind)`` with ``kind`` ``'git'`` or
        ``'blip'``, or ``(None, None, False, 'none')`` if both fail.
    """
    print("Loading caption model...")

    # Preferred: GIT-Large (microsoft/git-large-coco).
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM
        print(" Attempting GIT-Large (recommended)...")
        processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco",
            torch_dtype=dtype
        )
    except Exception as git_error:
        print(f" [INFO] GIT-Large not available: {git_error}")
    else:
        print(" [OK] GIT-Large model loaded (produces detailed captions, on CPU)")
        return processor, model, True, 'git'

    # Fallback: BLIP base.
    try:
        from transformers import BlipProcessor, BlipForConditionalGeneration
        print(" Attempting BLIP base (fallback)...")
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=dtype
        )
    except Exception as blip_error:
        print(f" [WARNING] Caption models not available: {blip_error}")
        print(" Caption generation will be disabled")
        return None, None, False, 'none'
    print(" [OK] BLIP base model loaded (standard captions, on CPU)")
    return processor, model, True, 'blip'
def set_clip_skip(pipe):
    """
    Log the configured ``CLIP_SKIP`` value for pipelines that have a text encoder.

    NOTE(review): despite its name, this function does not modify ``pipe``.
    It only prints ``CLIP_SKIP`` when ``pipe`` has a ``text_encoder``
    attribute. Confirm where (if anywhere) the value is actually applied.
    """
    if not hasattr(pipe, 'text_encoder'):
        return
    print(f" [OK] CLIP skip set to {CLIP_SKIP}")


print("[OK] Model loading functions ready")