Spaces:
Sleeping
Sleeping
| """ | |
| IP-Adapter Model Interface | |
| This module provides utilities for working with IP-Adapter models, including: | |
| - Loading Stable Diffusion pipelines with IP-Adapter | |
| - Extracting CLIP embeddings from images | |
| - Generating images from CLIP embeddings | |
| - Utility functions for image processing | |
| """ | |
| import logging | |
| import os | |
| from typing import List, Optional, Union, Tuple | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DDIMScheduler, AutoencoderKL | |
# Fix for torch 2.5.0 compatibility: switch off the cuDNN scaled-dot-product
# attention backend. The toggle is looked up first so that a torch build which
# does not expose it can still import this module instead of raising
# AttributeError at import time.
if hasattr(torch.backends.cuda, "enable_cudnn_sdp"):
    torch.backends.cuda.enable_cudnn_sdp(False)
| from ip_adapter import IPAdapterPlus, IPAdapterPlusXL | |
# ===== SD1.5 Base Model Resolution =====
# Model id that is always part of the SD1.5 candidate list; it is tried right
# after the environment-configured primary model, if one is set.
DEFAULT_SD15_BASE_MODEL = "SG161222/Realistic_Vision_V4.0_noVAE"
# Model ids appended to the candidate list when no fallback list is configured
# through the environment.
DEFAULT_SD15_FALLBACK_MODELS = (
    "runwayml/stable-diffusion-v1-5",
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
)
# Name of the environment variable holding a preferred base model id.
SD15_BASE_MODEL_ENV = "VIBESPACE_SD15_BASE_MODEL"
# Name of the environment variable holding a comma-separated list of fallback
# model ids; when set, it replaces DEFAULT_SD15_FALLBACK_MODELS.
SD15_FALLBACK_MODELS_ENV = "VIBESPACE_SD15_FALLBACK_MODELS"
| def _get_sd15_torch_dtype(device: str) -> torch.dtype: | |
| """Keep SD1.5 in fp16 on CUDA and use fp32 elsewhere for numerical stability.""" | |
| return torch.float16 if str(device).lower().startswith("cuda") else torch.float32 | |
| def _get_default_runtime_device() -> str: | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| if torch.backends.mps.is_available(): | |
| return "mps" | |
| return "cpu" | |
| def _get_ip_model_dtype(ip_model) -> torch.dtype: | |
| return getattr(ip_model, "torch_dtype", torch.float16) | |
def _get_sd15_base_model_candidates() -> List[str]:
    """Return the SD1.5 base model ids to try, in priority order, without duplicates.

    Order: the model named by the SD15_BASE_MODEL_ENV variable (if set), then
    DEFAULT_SD15_BASE_MODEL, then either the comma-separated ids from the
    SD15_FALLBACK_MODELS_ENV variable or, when that is unset or blank,
    DEFAULT_SD15_FALLBACK_MODELS. Only the first occurrence of an id is kept.
    """
    primary_override = os.getenv(SD15_BASE_MODEL_ENV, "").strip()
    fallback_override = os.getenv(SD15_FALLBACK_MODELS_ENV, "").strip()

    ordered: List[str] = [primary_override] if primary_override else []
    ordered.append(DEFAULT_SD15_BASE_MODEL)

    if fallback_override:
        stripped_entries = (entry.strip() for entry in fallback_override.split(","))
        # Blank entries (e.g. from a trailing comma) are skipped.
        ordered.extend(entry for entry in stripped_entries if entry)
    else:
        ordered.extend(DEFAULT_SD15_FALLBACK_MODELS)

    # dict keys keep first-insertion order, so later duplicates are dropped.
    return list(dict.fromkeys(ordered))
def _load_sd15_base_pipeline(
    noise_scheduler: DDIMScheduler,
    vae: AutoencoderKL,
    torch_dtype: torch.dtype,
) -> StableDiffusionPipeline:
    """Return a pipeline built from the first SD1.5 candidate that loads.

    Candidates come from ``_get_sd15_base_model_candidates`` and are tried in
    order. Each is loaded with the given scheduler, VAE and dtype, and with the
    feature extractor and safety checker disabled.

    Raises:
        RuntimeError: If every candidate fails; the last failure is chained
            as the cause.
    """
    model_ids = _get_sd15_base_model_candidates()
    final_position = len(model_ids) - 1
    most_recent_failure: Optional[Exception] = None

    for position, model_id in enumerate(model_ids):
        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch_dtype,
                scheduler=noise_scheduler,
                vae=vae,
                feature_extractor=None,
                safety_checker=None,
            )
        except Exception as exc:  # noqa: BLE001
            most_recent_failure = exc
            # The final failure is reported through the RuntimeError below,
            # so a warning is only logged while fallbacks remain.
            if position < final_position:
                logging.warning(
                    "Failed to load SD1.5 base model '%s'; trying fallback. Error: %s",
                    model_id,
                    exc,
                )
        else:
            return pipeline

    candidate_list = ", ".join(model_ids)
    raise RuntimeError(
        f"Failed to load any SD1.5 base model. Tried: {candidate_list}"
    ) from most_recent_failure
# ===== Image Utility Functions =====
def create_image_grid(images: List[Image.Image], rows: int, cols: int) -> Image.Image:
    """Tile images row by row onto a single RGB canvas.

    The cell size is taken from the first image, so all images are assumed to
    share that size. Image ``i`` is placed at row ``i // cols`` and column
    ``i % cols``; any image whose row index is ``rows`` or more is positioned
    outside the canvas.

    Args:
        images: Images to tile; must contain at least one image.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * width, rows * height)``.

    Raises:
        ValueError: If ``images`` is empty, because the cell size cannot be
            determined.
    """
    if not images:
        raise ValueError("create_image_grid requires at least one image")

    cell_width, cell_height = images[0].size
    grid = Image.new('RGB', size=(cols * cell_width, rows * cell_height))

    for index, img in enumerate(images):
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * cell_width, row * cell_height))
    return grid
# ===== CLIP Embedding Extraction Functions =====
def extract_clip_embeddings_from_pil(pil_image: Union[Image.Image, List[Image.Image]],
                                     ip_model) -> torch.Tensor:
    """Encode one PIL image, or a list of them, with the IP-Adapter's CLIP image encoder.

    Args:
        pil_image: A single image or a list of images; a single image is
            wrapped into a one-element list.
        ip_model: Object providing ``clip_image_processor``, ``image_encoder``
            and ``device``.

    Returns:
        torch.Tensor: CLIP embeddings of shape (batch_size, seq_len, embed_dim)
    """
    image_list = [pil_image] if isinstance(pil_image, Image.Image) else pil_image

    # Preprocess, then move the pixel tensor to the model's device and dtype.
    preprocessed = ip_model.clip_image_processor(images=image_list, return_tensors="pt")
    pixel_values = preprocessed.pixel_values.to(
        ip_model.device,
        dtype=_get_ip_model_dtype(ip_model),
    )

    # The second-to-last hidden state is used as the embedding.
    encoder_output = ip_model.image_encoder(pixel_values, output_hidden_states=True)
    penultimate_hidden = encoder_output.hidden_states[-2]

    # Always hand back float32, regardless of the model dtype.
    return penultimate_hidden.float()
def extract_clip_embeddings_from_pil_batch(pil_images: List[Image.Image],
                                           ip_model) -> torch.Tensor:
    """Encode images one at a time and stack the results along the batch axis.

    Each image goes through ``extract_clip_embeddings_from_pil`` separately,
    so the encoder only ever sees a single image per call.

    Returns:
        torch.Tensor: Concatenated CLIP embeddings of shape (batch, seq_len, embed_dim)
    """
    per_image_embeddings = [
        extract_clip_embeddings_from_pil(single_image, ip_model)
        for single_image in pil_images
    ]
    return torch.cat(per_image_embeddings, dim=0)
def extract_clip_embeddings_from_tensor(tensor_image: torch.Tensor,
                                        ip_model,
                                        resize: bool = True) -> torch.Tensor:
    """Encode an image tensor with the IP-Adapter's CLIP image encoder.

    Args:
        tensor_image: Image tensor; assumed to be (batch, channels, height,
            width), as required by the 2-D bilinear resize. TODO confirm
            the expected value range/normalisation with callers.
        ip_model: Object providing ``image_encoder`` and ``device``.
        resize: When true, bilinearly resize the input to 224x224 first.

    Returns:
        torch.Tensor: CLIP embeddings of shape (batch_size, seq_len, embed_dim)
    """
    pixels = tensor_image.to(
        ip_model.device,
        dtype=_get_ip_model_dtype(ip_model),
    )

    if resize:
        pixels = torch.nn.functional.interpolate(
            pixels,
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )

    # Positional-encoding interpolation is requested from the encoder; the
    # second-to-last hidden state is used as the embedding.
    encoder_output = ip_model.image_encoder(
        pixels,
        output_hidden_states=True,
        interpolate_pos_encoding=True,
    )
    penultimate_hidden = encoder_output.hidden_states[-2]

    # Always hand back float32, regardless of the model dtype.
    return penultimate_hidden.float()
| # ===== IP-Adapter Helper Functions ===== | |
| def _enhanced_get_image_embeds(self, pil_image=None, clip_image_embeds=None): | |
| """ | |
| Enhanced version of IP-Adapter's get_image_embeds method. | |
| This method processes either PIL images or pre-computed CLIP embeddings | |
| and returns both conditional and unconditional embeddings for generation. | |
| Args: | |
| pil_image: PIL Image(s) to process (optional) | |
| clip_image_embeds: Pre-computed CLIP embeddings (optional) | |
| Returns: | |
| Tuple of (conditional_embeds, unconditional_embeds) | |
| """ | |
| # Process PIL images if provided | |
| model_dtype = getattr(self, "torch_dtype", torch.float16) | |
| if pil_image is not None: | |
| if isinstance(pil_image, Image.Image): | |
| pil_image = [pil_image] | |
| # Convert PIL to tensor and extract CLIP embeddings | |
| processed_images = self.clip_image_processor( | |
| images=pil_image, return_tensors="pt" | |
| ).pixel_values | |
| processed_images = processed_images.to(self.device, dtype=model_dtype) | |
| clip_image_embeds = self.image_encoder( | |
| processed_images, output_hidden_states=True | |
| ).hidden_states[-2] | |
| else: | |
| clip_image_embeds = clip_image_embeds.to(self.device, dtype=model_dtype) | |
| # Project CLIP embeddings to IP-Adapter space | |
| conditional_embeds = self.image_proj_model(clip_image_embeds) | |
| # Generate unconditional embeddings (for classifier-free guidance) | |
| zero_tensor = torch.zeros( | |
| clip_image_embeds.shape[0], | |
| 3, | |
| 224, | |
| 224, | |
| device=self.device, | |
| dtype=model_dtype, | |
| ) | |
| uncond_clip_embeds = self.image_encoder( | |
| zero_tensor, output_hidden_states=True | |
| ).hidden_states[-2] | |
| unconditional_embeds = self.image_proj_model(uncond_clip_embeds) | |
| return conditional_embeds, unconditional_embeds | |
# ===== Model Loading Functions =====
def load_stable_diffusion_pipeline(device: str = "cuda") -> StableDiffusionPipeline:
    """Build an SD1.5 pipeline with a DDIM scheduler and the sd-vae-ft-mse VAE.

    Args:
        device: Only used to choose the dtype (fp16 for CUDA, fp32 otherwise);
            this function does not itself move the pipeline to the device.

    Returns:
        The first SD1.5 base pipeline that loads successfully.
    """
    dtype = _get_sd15_torch_dtype(device)

    # DDIM scheduler settings used for SD1.5 sampling.
    scheduler_settings = {
        "num_train_timesteps": 1000,
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "clip_sample": False,
        "set_alpha_to_one": False,
        "steps_offset": 1,
    }
    scheduler = DDIMScheduler(**scheduler_settings)

    # The VAE is loaded separately and cast to the pipeline dtype.
    standalone_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=dtype)

    return _load_sd15_base_pipeline(
        noise_scheduler=scheduler,
        vae=standalone_vae,
        torch_dtype=dtype,
    )
def load_ip_adapter_model(
    device: str = "cuda",
    sd_only: bool = False,
) -> IPAdapterPlus | StableDiffusionPipeline:
    """Load the SD1.5 IP-Adapter Plus model, or only its base pipeline.

    The base pipeline (DDIM scheduler, sd-vae-ft-mse VAE, SD1.5 base model with
    fallbacks) is built by ``load_stable_diffusion_pipeline`` so the setup
    lives in one place.

    Args:
        device: Device handed to IPAdapterPlus; also selects the pipeline
            dtype (fp16 for CUDA, fp32 otherwise).
        sd_only: When true, return the bare Stable Diffusion pipeline without
            wrapping it in IP-Adapter.

    Returns:
        The IPAdapterPlus wrapper, or the StableDiffusionPipeline when
        ``sd_only`` is true.
    """
    # Local checkpoint paths
    image_encoder_path = "./downloads/models/image_encoder"
    ip_checkpoint_path = "./downloads/models/ip-adapter-plus_sd15.bin"

    pipeline = load_stable_diffusion_pipeline(device)
    if sd_only:
        return pipeline

    # Initialize IP-Adapter with 16 image tokens
    ip_model = IPAdapterPlus(
        pipeline,
        image_encoder_path,
        ip_checkpoint_path,
        device,
        num_tokens=16,
    )

    # Swap in our get_image_embeds. NOTE: this patches the class, not just this
    # instance, so every instance of the class uses the replacement.
    setattr(ip_model.__class__, "get_image_embeds", _enhanced_get_image_embeds)
    return ip_model
def load_ip_adapter_xl_model(device: str = "cuda") -> IPAdapterPlusXL:
    """Load the SDXL IP-Adapter Plus model on top of RealVisXL V1.0.

    The SDXL pipeline is always loaded in float16 with the watermarker
    disabled, independent of ``device``.

    Args:
        device: Device handed to IPAdapterPlusXL.

    Returns:
        The IPAdapterPlusXL wrapper configured with 16 image tokens.
    """
    sdxl_pipeline = StableDiffusionXLPipeline.from_pretrained(
        "SG161222/RealVisXL_V1.0",
        torch_dtype=torch.float16,
        add_watermarker=False,
    )
    return IPAdapterPlusXL(
        sdxl_pipeline,
        "./downloads/models/image_encoder",
        "./downloads/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",
        device,
        num_tokens=16,
    )
def load_ipadapter(version: str = "sd15", device: Optional[str] = None) -> IPAdapterPlus | IPAdapterPlusXL:
    """Load the IP-Adapter model for the requested Stable Diffusion version.

    Args:
        version: "sd15" or "sdxl".
        device: Target device; when falsy, the default runtime device is
            detected (CUDA, then MPS, then CPU).

    Raises:
        ValueError: If ``version`` is neither "sd15" nor "sdxl".
    """
    resolved_device = device or _get_default_runtime_device()

    if version == "sd15":
        return load_ip_adapter_model(resolved_device)
    if version == "sdxl":
        return load_ip_adapter_xl_model(resolved_device)
    raise ValueError(f"Invalid version: {version}")
# ===== Image Generation Functions =====
def generate_images_from_clip_embeddings(ip_model : IPAdapterPlus,
                                         clip_embeddings: torch.Tensor,
                                         num_samples: int = 4,
                                         num_inference_steps: int = 50,
                                         seed: Optional[int] = 42) -> List[Image.Image]:
    """Generate images from CLIP embeddings using IP-Adapter.

    Args:
        ip_model: IP-Adapter model whose ``generate`` method is called.
        clip_embeddings: Tensor of shape (batch, seq_len, embed_dim); a 2-D
            tensor is given a leading batch axis.
        num_samples: Forwarded to ``ip_model.generate``.
        num_inference_steps: Forwarded to ``ip_model.generate``.
        seed: Forwarded to ``ip_model.generate``.

    Returns:
        Whatever ``ip_model.generate`` returns (annotated as a list of PIL images).

    Raises:
        ValueError: If the embeddings are not 3-D after the optional batch
            axis is added.
    """
    # Promote (seq, dim) to (1, seq, dim); every other rank is left untouched.
    if clip_embeddings.ndim == 2:
        clip_embeddings = clip_embeddings[None]
    if clip_embeddings.ndim != 3:
        raise ValueError(f"Expected 3D embeddings (batch, seq, dim), got {clip_embeddings.shape}")

    # Fixed negative prompt passed with every generation call.
    negative_prompt = (
        "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, "
        "jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, "
        "early, chromatic aberration, signature, extra digits, artistic error, "
        "username, scan, [abstract]"
    )

    # Match the model's device and dtype before generating.
    embeddings_on_model = clip_embeddings.to(
        ip_model.device,
        dtype=_get_ip_model_dtype(ip_model),
    )

    return ip_model.generate(
        clip_image_embeds=embeddings_on_model,
        negative_prompt=negative_prompt,
        pil_image=None,
        num_samples=num_samples,
        num_inference_steps=num_inference_steps,
        seed=seed,
    )
# ===== Legacy Function Aliases =====
# Maintain backward compatibility with existing code: each old name is bound
# to the same function object as its renamed counterpart.
image_grid = create_image_grid
extract_clip_embedding_pil = extract_clip_embeddings_from_pil
extract_clip_embedding_pil_batch = extract_clip_embeddings_from_pil_batch
extract_clip_embedding_tensor = extract_clip_embeddings_from_tensor
# NOTE(review): despite its name, load_sdxl points at the SD1.5 pipeline
# loader, not an SDXL one - confirm this is what legacy callers expect.
load_sdxl = load_stable_diffusion_pipeline
generate = generate_images_from_clip_embeddings