# VibeSpace / ipadapter_model.py
# huzey's picture
# Add CPU fallback for HF custom feature stages
# 22afff9
"""
IP-Adapter Model Interface
This module provides utilities for working with IP-Adapter models, including:
- Loading Stable Diffusion pipelines with IP-Adapter
- Extracting CLIP embeddings from images
- Generating images from CLIP embeddings
- Utility functions for image processing
"""
import logging
import os
from typing import List, Optional, Union, Tuple
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DDIMScheduler, AutoencoderKL
# Fix for torch 2.5.0 compatibility: switch off the cuDNN scaled-dot-product
# attention backend. The switch is looked up defensively so that torch builds
# which do not expose it can still import this module.
if hasattr(torch.backends.cuda, "enable_cudnn_sdp"):
    torch.backends.cuda.enable_cudnn_sdp(False)
from ip_adapter import IPAdapterPlus, IPAdapterPlusXL
# ===== SD1.5 Base Model Resolution =====
# Built-in base checkpoint; _get_sd15_base_model_candidates() always lists it,
# right after any primary model configured through the environment.
DEFAULT_SD15_BASE_MODEL = "SG161222/Realistic_Vision_V4.0_noVAE"
# Built-in fallback checkpoints, used only when SD15_FALLBACK_MODELS_ENV is
# unset or blank.
DEFAULT_SD15_FALLBACK_MODELS = (
    "runwayml/stable-diffusion-v1-5",
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
)
# Name of the environment variable holding a primary base model id that is
# tried before DEFAULT_SD15_BASE_MODEL.
SD15_BASE_MODEL_ENV = "VIBESPACE_SD15_BASE_MODEL"
# Name of the environment variable holding a comma-separated list of fallback
# model ids that replaces DEFAULT_SD15_FALLBACK_MODELS when non-blank.
SD15_FALLBACK_MODELS_ENV = "VIBESPACE_SD15_FALLBACK_MODELS"
def _get_sd15_torch_dtype(device: str) -> torch.dtype:
    """Choose the tensor dtype for the SD1.5 stack from the target device.

    Args:
        device: Device identifier. It is passed through ``str()`` and
            lower-cased before the prefix test, so values such as
            ``"CUDA:0"`` are treated like ``"cuda"``.

    Returns:
        ``torch.float16`` when the identifier starts with ``"cuda"``;
        ``torch.float32`` for everything else (kept at full precision for
        numerical stability).
    """
    normalized = str(device).lower()
    if normalized.startswith("cuda"):
        return torch.float16
    return torch.float32
def _get_default_runtime_device() -> str:
    """Return the preferred device name for this machine.

    CUDA is checked first, then MPS; ``"cpu"`` is the answer when neither
    backend reports itself as available.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        # The MPS backend is only queried when CUDA is not available.
        device_name = "mps"
    else:
        device_name = "cpu"
    return device_name
def _get_ip_model_dtype(ip_model) -> torch.dtype:
    """Return the dtype an IP-Adapter wrapper says it runs in.

    Args:
        ip_model: Any object; its ``torch_dtype`` attribute is read if present.

    Returns:
        ``ip_model.torch_dtype`` when the attribute exists, otherwise
        ``torch.float16``.
    """
    try:
        return ip_model.torch_dtype
    except AttributeError:
        # Wrappers that do not advertise a dtype are treated as half precision.
        return torch.float16
def _get_sd15_base_model_candidates() -> List[str]:
    """Return the SD1.5 base model ids to try, in priority order.

    The order is:

    1. the model named by the ``SD15_BASE_MODEL_ENV`` environment variable,
       when it is set to a non-blank value;
    2. ``DEFAULT_SD15_BASE_MODEL``;
    3. the comma-separated ids from ``SD15_FALLBACK_MODELS_ENV`` when that
       variable is non-blank, otherwise ``DEFAULT_SD15_FALLBACK_MODELS``.

    Returns:
        The candidate ids with duplicates removed; the first occurrence of
        each id keeps its position.
    """
    primary_override = os.getenv(SD15_BASE_MODEL_ENV, "").strip()
    fallback_override = os.getenv(SD15_FALLBACK_MODELS_ENV, "").strip()

    ordered: List[str] = [primary_override] if primary_override else []
    ordered.append(DEFAULT_SD15_BASE_MODEL)

    if fallback_override:
        # Blank entries (e.g. from a trailing comma) are dropped.
        ordered.extend(
            [entry.strip() for entry in fallback_override.split(",") if entry.strip()]
        )
    else:
        ordered.extend(DEFAULT_SD15_FALLBACK_MODELS)

    # dict keys keep insertion order, so this removes later duplicates while
    # leaving the first occurrence of every id where it was.
    return list(dict.fromkeys(ordered))
def _load_sd15_base_pipeline(
    noise_scheduler: DDIMScheduler,
    vae: AutoencoderKL,
    torch_dtype: torch.dtype,
) -> StableDiffusionPipeline:
    """Load the first SD1.5 base pipeline that can be instantiated.

    Candidates come from ``_get_sd15_base_model_candidates()`` and are tried
    in order; the first successful ``from_pretrained`` call wins.

    Args:
        noise_scheduler: Scheduler passed to the pipeline as ``scheduler``.
        vae: VAE passed to the pipeline as ``vae``.
        torch_dtype: dtype forwarded to ``from_pretrained``.

    Returns:
        The loaded pipeline, created without a feature extractor or safety
        checker.

    Raises:
        RuntimeError: If every candidate fails. The message lists all ids
            that were tried and the exception is chained from the last
            failure.
    """
    model_ids = _get_sd15_base_model_candidates()
    final_position = len(model_ids) - 1
    shared_kwargs = dict(
        torch_dtype=torch_dtype,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )

    failure: Optional[Exception] = None
    for position, model_id in enumerate(model_ids):
        try:
            loaded = StableDiffusionPipeline.from_pretrained(model_id, **shared_kwargs)
        except Exception as exc:  # noqa: BLE001
            # Any failure moves on to the next candidate; only the last error
            # is kept for chaining.
            failure = exc
            if position != final_position:
                logging.warning(
                    "Failed to load SD1.5 base model '%s'; trying fallback. Error: %s",
                    model_id,
                    exc,
                )
        else:
            return loaded

    candidate_list = ", ".join(model_ids)
    raise RuntimeError(
        f"Failed to load any SD1.5 base model. Tried: {candidate_list}"
    ) from failure
# ===== Image Utility Functions =====
def create_image_grid(images: List[Image.Image], rows: int, cols: int) -> Image.Image:
    """Tile images onto a single RGB canvas in row-major order.

    Args:
        images: Images to place. The cell size is taken from the first one;
            the remaining images are assumed to match it (not checked).
        rows: Number of grid rows, used for the canvas height.
        cols: Number of grid columns, used for the canvas width and for
            wrapping to the next row.

    Returns:
        A new ``'RGB'`` image of size ``(cols * width, rows * height)``.
        No check is made that ``len(images)`` equals ``rows * cols``.
    """
    cell_w, cell_h = images[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(images):
        # divmod gives (row, column) for a row-major layout.
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
# ===== CLIP Embedding Extraction Functions =====
@torch.inference_mode()
def extract_clip_embeddings_from_pil(pil_image: Union[Image.Image, List[Image.Image]],
                                     ip_model) -> torch.Tensor:
    """Encode one or more PIL images with the IP-Adapter's image encoder.

    Args:
        pil_image: A single PIL image or a list of them; a single image is
            wrapped into a one-element list.
        ip_model: Object providing ``clip_image_processor``, ``image_encoder``
            and ``device``; its dtype is resolved with ``_get_ip_model_dtype``.

    Returns:
        torch.Tensor: CLIP embeddings of shape (batch_size, seq_len, embed_dim),
        taken from the second-to-last hidden state and cast to float32.
    """
    images = [pil_image] if isinstance(pil_image, Image.Image) else pil_image

    # Preprocess, then move the pixel tensor to the model's device and dtype.
    pixel_values = ip_model.clip_image_processor(
        images=images, return_tensors="pt"
    ).pixel_values
    pixel_values = pixel_values.to(
        ip_model.device,
        dtype=_get_ip_model_dtype(ip_model),
    )

    # Use the penultimate hidden state rather than the final one.
    encoder_output = ip_model.image_encoder(pixel_values, output_hidden_states=True)
    penultimate = encoder_output.hidden_states[-2]

    # float32 for better numerical stability downstream.
    return penultimate.float()
@torch.inference_mode()
def extract_clip_embeddings_from_pil_batch(pil_images: List[Image.Image],
                                           ip_model) -> torch.Tensor:
    """Encode a list of PIL images one at a time and stack the results.

    Each image goes through ``extract_clip_embeddings_from_pil`` separately,
    and the per-image outputs are concatenated along dimension 0.

    Args:
        pil_images: Images to encode.
        ip_model: IP-Adapter wrapper forwarded to the single-image extractor.

    Returns:
        torch.Tensor: Concatenated CLIP embeddings of shape (batch, seq_len, embed_dim)
    """
    per_image = [
        extract_clip_embeddings_from_pil(single, ip_model) for single in pil_images
    ]
    return torch.cat(per_image, dim=0)
@torch.inference_mode()
def extract_clip_embeddings_from_tensor(tensor_image: torch.Tensor,
                                        ip_model,
                                        resize: bool = True) -> torch.Tensor:
    """Encode an image tensor with the IP-Adapter's image encoder.

    Args:
        tensor_image: Image tensor; presumably (batch, channels, height,
            width) and already normalised for CLIP — TODO confirm with callers.
        ip_model: Object providing ``image_encoder`` and ``device``; its dtype
            is resolved with ``_get_ip_model_dtype``.
        resize: When true, bilinearly resize the input to 224x224 before
            encoding.

    Returns:
        torch.Tensor: CLIP embeddings of shape (batch_size, seq_len, embed_dim),
        taken from the second-to-last hidden state and cast to float32.
    """
    pixels = tensor_image.to(
        ip_model.device,
        dtype=_get_ip_model_dtype(ip_model),
    )

    if resize:
        # Bring the input to the 224x224 CLIP resolution.
        pixels = torch.nn.functional.interpolate(
            pixels, size=(224, 224), mode="bilinear", align_corners=False
        )

    # Positional-encoding interpolation is always requested, with or without
    # the resize above.
    encoder_output = ip_model.image_encoder(
        pixels,
        output_hidden_states=True,
        interpolate_pos_encoding=True,
    )
    return encoder_output.hidden_states[-2].float()
# ===== IP-Adapter Helper Functions =====
@torch.inference_mode()
def _enhanced_get_image_embeds(self, pil_image=None, clip_image_embeds=None):
    """
    Enhanced version of IP-Adapter's get_image_embeds method.

    Processes either PIL images or pre-computed CLIP embeddings and returns
    both conditional and unconditional embeddings for generation. When both
    inputs are given, ``pil_image`` takes precedence and ``clip_image_embeds``
    is ignored.

    Args:
        self: IP-Adapter instance providing ``clip_image_processor``,
            ``image_encoder``, ``image_proj_model`` and ``device``; its
            ``torch_dtype`` attribute is used when present, else float16.
        pil_image: PIL Image(s) to process (optional)
        clip_image_embeds: Pre-computed CLIP embeddings (optional)

    Returns:
        Tuple of (conditional_embeds, unconditional_embeds)

    Raises:
        ValueError: If neither ``pil_image`` nor ``clip_image_embeds`` is given.
    """
    if pil_image is None and clip_image_embeds is None:
        # Fail early with a clear message instead of an AttributeError on None.
        raise ValueError(
            "Either pil_image or clip_image_embeds must be provided"
        )

    model_dtype = getattr(self, "torch_dtype", torch.float16)

    if pil_image is not None:
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        # Convert PIL to tensor and extract CLIP embeddings from the
        # second-to-last hidden state.
        processed_images = self.clip_image_processor(
            images=pil_image, return_tensors="pt"
        ).pixel_values
        processed_images = processed_images.to(self.device, dtype=model_dtype)
        clip_image_embeds = self.image_encoder(
            processed_images, output_hidden_states=True
        ).hidden_states[-2]
    else:
        clip_image_embeds = clip_image_embeds.to(self.device, dtype=model_dtype)

    # Project CLIP embeddings to IP-Adapter space
    conditional_embeds = self.image_proj_model(clip_image_embeds)

    # Unconditional embeddings (for classifier-free guidance) come from
    # encoding an all-zero 224x224 image batch of the same batch size.
    zero_tensor = torch.zeros(
        clip_image_embeds.shape[0],
        3,
        224,
        224,
        device=self.device,
        dtype=model_dtype,
    )
    uncond_clip_embeds = self.image_encoder(
        zero_tensor, output_hidden_states=True
    ).hidden_states[-2]
    unconditional_embeds = self.image_proj_model(uncond_clip_embeds)

    return conditional_embeds, unconditional_embeds
# ===== Model Loading Functions =====
@torch.inference_mode()
def load_stable_diffusion_pipeline(device: str = "cuda") -> StableDiffusionPipeline:
    """Build the SD1.5 pipeline with a DDIM scheduler and a separate VAE.

    Args:
        device: Target device name. In this function it is only used to pick
            the dtype via ``_get_sd15_torch_dtype``; the pipeline is not
            moved to the device here.

    Returns:
        The first SD1.5 base pipeline that ``_load_sd15_base_pipeline`` manages
        to load.
    """
    dtype = _get_sd15_torch_dtype(device)

    # DDIM scheduler configuration.
    scheduler_config = dict(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    scheduler = DDIMScheduler(**scheduler_config)

    # The VAE is loaded on its own and cast to the selected dtype.
    standalone_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=dtype)

    return _load_sd15_base_pipeline(
        noise_scheduler=scheduler,
        vae=standalone_vae,
        torch_dtype=dtype,
    )
@torch.inference_mode()
def load_ip_adapter_model(
    device: str = "cuda",
    sd_only: bool = False,
) -> IPAdapterPlus | StableDiffusionPipeline:
    """Load the SD1.5 IP-Adapter Plus model, or only its base pipeline.

    Args:
        device: Device name handed to ``IPAdapterPlus``; it also selects the
            pipeline dtype inside ``load_stable_diffusion_pipeline``.
        sd_only: When true, return the bare Stable Diffusion pipeline without
            wrapping it in an IP-Adapter.

    Returns:
        The base pipeline when ``sd_only`` is true, otherwise an
        ``IPAdapterPlus`` whose ``get_image_embeds`` is replaced by
        ``_enhanced_get_image_embeds``.
    """
    image_encoder_path = "./downloads/models/image_encoder"
    ip_checkpoint_path = "./downloads/models/ip-adapter-plus_sd15.bin"

    # Reuse the shared SD1.5 loader (same DDIM scheduler, VAE and base-model
    # fallbacks) instead of keeping a second copy of that setup here.
    pipeline = load_stable_diffusion_pipeline(device)
    if sd_only:
        return pipeline

    # Initialize IP-Adapter with 16 image tokens.
    ip_model = IPAdapterPlus(
        pipeline,
        image_encoder_path,
        ip_checkpoint_path,
        device,
        num_tokens=16,
    )

    # NOTE: the replacement is installed on the class, not the instance, so it
    # applies to every instance of this class in the process.
    setattr(ip_model.__class__, "get_image_embeds", _enhanced_get_image_embeds)
    return ip_model
def load_ip_adapter_xl_model(device: str = "cuda") -> IPAdapterPlusXL:
    """Load the SDXL IP-Adapter Plus model.

    Args:
        device: Device name handed to ``IPAdapterPlusXL``.

    Returns:
        An ``IPAdapterPlusXL`` wrapping the SDXL base pipeline, configured
        with 16 image tokens.
    """
    sdxl_base_id = "SG161222/RealVisXL_V1.0"
    encoder_dir = "./downloads/models/image_encoder"
    adapter_checkpoint = "./downloads/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin"

    # NOTE(review): the dtype is fixed to float16 here regardless of `device`,
    # unlike the SD1.5 loaders which pick it per device — confirm this is
    # intended for non-CUDA devices.
    sdxl_pipeline = StableDiffusionXLPipeline.from_pretrained(
        sdxl_base_id,
        torch_dtype=torch.float16,
        add_watermarker=False,
    )
    return IPAdapterPlusXL(
        sdxl_pipeline, encoder_dir, adapter_checkpoint, device, num_tokens=16
    )
def load_ipadapter(version: str = "sd15", device: Optional[str] = None) -> IPAdapterPlus | IPAdapterPlusXL:
    """Load an IP-Adapter model by version name.

    Args:
        version: ``"sd15"`` or ``"sdxl"``.
        device: Device name; when falsy, the result of
            ``_get_default_runtime_device()`` is used.

    Returns:
        The model built by ``load_ip_adapter_model`` (``"sd15"``) or
        ``load_ip_adapter_xl_model`` (``"sdxl"``).

    Raises:
        ValueError: If ``version`` is neither ``"sd15"`` nor ``"sdxl"``.
    """
    resolved_device = device if device else _get_default_runtime_device()

    if version == "sd15":
        return load_ip_adapter_model(resolved_device)
    if version == "sdxl":
        return load_ip_adapter_xl_model(resolved_device)
    raise ValueError(f"Invalid version: {version}")
# ===== Image Generation Functions =====
@torch.inference_mode()
def generate_images_from_clip_embeddings(ip_model: IPAdapterPlus,
                                         clip_embeddings: torch.Tensor,
                                         num_samples: int = 4,
                                         num_inference_steps: int = 50,
                                         seed: Optional[int] = 42,
                                         negative_prompt: Optional[str] = None) -> List[Image.Image]:
    """Generate images from CLIP embeddings using IP-Adapter.

    Args:
        ip_model: IP-Adapter wrapper providing ``device`` and ``generate``.
        clip_embeddings: Embeddings of shape (batch, seq_len, embed_dim). A
            2-D tensor is treated as a single item and gets a leading batch
            dimension.
        num_samples: Forwarded to ``ip_model.generate``.
        num_inference_steps: Forwarded to ``ip_model.generate``.
        seed: Forwarded to ``ip_model.generate``.
        negative_prompt: Negative prompt to use. ``None`` (the default)
            selects the built-in prompt below; any string, including ``""``,
            is forwarded unchanged.

    Returns:
        Whatever ``ip_model.generate`` returns (annotated as a list of PIL
        images).

    Raises:
        ValueError: If the embeddings are not 3-D after the optional
            batch-dimension fix-up.
    """
    # Ensure embeddings have the expected rank.
    if clip_embeddings.ndim == 2:
        clip_embeddings = clip_embeddings.unsqueeze(0)
    if clip_embeddings.ndim != 3:
        raise ValueError(f"Expected 3D embeddings (batch, seq, dim), got {clip_embeddings.shape}")

    # Move to the model's device and dtype.
    clip_embeddings = clip_embeddings.to(
        ip_model.device,
        dtype=_get_ip_model_dtype(ip_model),
    )

    if negative_prompt is None:
        # Built-in default, kept identical to the previously hard-coded text.
        negative_prompt = "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"

    # pil_image is passed as None so generation is driven by the embeddings.
    generated_images = ip_model.generate(
        clip_image_embeds=clip_embeddings,
        negative_prompt=negative_prompt,
        pil_image=None,
        num_samples=num_samples,
        num_inference_steps=num_inference_steps,
        seed=seed,
    )
    return generated_images
# ===== Legacy Function Aliases =====
# Maintain backward compatibility with existing code: each old name below is
# bound to the same function object as the current name on its right.
image_grid = create_image_grid
extract_clip_embedding_pil = extract_clip_embeddings_from_pil
extract_clip_embedding_pil_batch = extract_clip_embeddings_from_pil_batch
extract_clip_embedding_tensor = extract_clip_embeddings_from_tensor
# NOTE(review): despite its name, `load_sdxl` is bound to the SD1.5 pipeline
# loader, not to the SDXL loader (`load_ip_adapter_xl_model`) — confirm that
# existing callers expect this.
load_sdxl = load_stable_diffusion_pipeline
generate = generate_images_from_clip_embeddings