""" Feature Extraction Module This module provides utilities for extracting features from images using various pre-trained models including DINO, DINOv3, and CLIP. It handles model loading, batch processing, and memory management for efficient feature extraction. """ import gc import logging from typing import Tuple, Optional import torch import torch.nn as nn from einops import rearrange from torchvision import transforms from ipadapter_model import extract_clip_embedding_tensor from ipadapter_model import load_ipadapter # Default hyperparameters DEFAULT_BATCH_SIZE = 32 # ===== Image Transforms ===== # High-resolution transform for DINO models dino_image_transform = transforms.Compose([ transforms.Resize((256 * 2, 256 * 2)), # High resolution for detailed features transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # Standard resolution transform for CLIP models clip_image_transform = transforms.Compose([ transforms.Resize((224, 224)), # Standard ImageNet resolution transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # Inverse transform to convert normalized tensors back to PIL images image_inverse_transform = transforms.Compose([ transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1/0.229, 1/0.224, 1/0.225]), transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1.0, 1.0, 1.0]), transforms.ToPILImage(), ]) # ===== Memory Management ===== def clear_gpu_memory(): """Clear GPU cache and run garbage collection to free memory.""" if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() def _get_feature_device_candidates() -> list[str]: """Prefer CUDA when available, but allow CPU fallback for unsupported kernels.""" return ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] def _should_retry_on_cpu(exc: RuntimeError, device: str) -> bool: if device != "cuda": return False error_message = str(exc).lower() return "no kernel image is available for execution on the device" in error_message # ===== Feature Extraction Functions ===== @torch.no_grad() def extract_dino_features(images: torch.Tensor, batch_size: int = DEFAULT_BATCH_SIZE) -> torch.Tensor: """ Extract features using DINO ViT-S/16 model. Args: images (torch.Tensor): Input images of shape (N, C, H, W) batch_size (int): Batch size for processing Returns: torch.Tensor: DINO features of shape (N, L, D) """ last_error: Optional[RuntimeError] = None for device in _get_feature_device_candidates(): dino_model = None try: dino_model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') dino_model = dino_model.eval().to(device) # Process images in batches num_batches = (images.shape[0] + batch_size - 1) // batch_size feature_batches = [] for batch_idx in range(num_batches): start_idx = batch_idx * batch_size end_idx = min((batch_idx + 1) * batch_size, images.shape[0]) batch_images = images[start_idx:end_idx].to(device) batch_features = dino_model.get_intermediate_layers(batch_images)[-1] feature_batches.append(batch_features.cpu()) # Concatenate all batches return torch.cat(feature_batches, dim=0) except RuntimeError as exc: last_error = exc if _should_retry_on_cpu(exc, device): logging.warning("DINO CUDA kernels are unavailable on this device; retrying feature extraction on CPU.") continue raise finally: if dino_model is not None: del dino_model clear_gpu_memory() if last_error is not None: raise last_error raise RuntimeError("Failed to extract DINO features.") @torch.no_grad() def extract_clip_features(images: torch.Tensor, batch_size: int = DEFAULT_BATCH_SIZE, ipadapter_version: str = "sd15") -> torch.Tensor: """ Extract features using CLIP vision encoder. Args: images (torch.Tensor): Input images of shape (N, C, H, W) batch_size (int): Batch size for processing Returns: torch.Tensor: CLIP features of shape (N, L, D) """ last_error: Optional[RuntimeError] = None for device in _get_feature_device_candidates(): ip_adapter_model = None try: # Load IP-Adapter model (contains CLIP encoder) ip_adapter_model = load_ipadapter(version=ipadapter_version, device=device) # Process images in batches num_batches = (images.shape[0] + batch_size - 1) // batch_size feature_batches = [] for batch_idx in range(num_batches): start_idx = batch_idx * batch_size end_idx = min((batch_idx + 1) * batch_size, images.shape[0]) batch_images = images[start_idx:end_idx].to(device) batch_features = extract_clip_embedding_tensor( batch_images, ip_adapter_model, resize=False ) feature_batches.append(batch_features.cpu()) # Concatenate all batches return torch.cat(feature_batches, dim=0) except RuntimeError as exc: last_error = exc if _should_retry_on_cpu(exc, device): logging.warning("CLIP CUDA kernels are unavailable on this device; retrying feature extraction on CPU.") continue raise finally: if ip_adapter_model is not None: del ip_adapter_model clear_gpu_memory() if last_error is not None: raise last_error raise RuntimeError("Failed to extract CLIP features.")