Spaces:
Sleeping
Sleeping
| """ | |
| Feature Extraction Module | |
| This module provides utilities for extracting features from images using various | |
| pre-trained models including DINO, DINOv3, and CLIP. It handles model loading, | |
| batch processing, and memory management for efficient feature extraction. | |
| """ | |
| import gc | |
| import logging | |
| from typing import Tuple, Optional | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from torchvision import transforms | |
| from ipadapter_model import extract_clip_embedding_tensor | |
| from ipadapter_model import load_ipadapter | |
# Default hyperparameters
DEFAULT_BATCH_SIZE = 32  # default images per forward pass for the extract_* functions

# ===== Image Transforms =====
# ImageNet channel statistics used by both forward transforms below; the inverse
# transform undoes exactly this normalization. Tuples so they cannot be mutated
# through one transform and silently affect another.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# High-resolution transform for DINO models (512x512 for detailed features)
dino_image_transform = transforms.Compose([
    transforms.Resize((256 * 2, 256 * 2)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Standard resolution transform for CLIP models (224x224, standard ImageNet size).
# NOTE(review): this normalizes with ImageNet statistics rather than the OpenAI CLIP
# ones; confirm this matches what extract_clip_embedding_tensor expects.
clip_image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


def _clamp_to_unit_interval(tensor: torch.Tensor) -> torch.Tensor:
    """Clamp a de-normalized image tensor to [0, 1] before PIL conversion."""
    # ToPILImage converts float tensors via a float -> uint8 cast of (x * 255);
    # values outside [0, 1] would wrap/behave undefined instead of saturating.
    # A named module-level function (not a lambda) keeps the transform picklable.
    return tensor.clamp(0.0, 1.0)


# Inverse transform to convert normalized tensors back to PIL images:
# first x / (1 / std) = x * std, then x - (-mean) = x + mean, then clamp to [0, 1].
image_inverse_transform = transforms.Compose([
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1 / s for s in _IMAGENET_STD]),
    transforms.Normalize(mean=[-m for m in _IMAGENET_MEAN], std=[1.0, 1.0, 1.0]),
    transforms.Lambda(_clamp_to_unit_interval),
    transforms.ToPILImage(),
])
| # ===== Memory Management ===== | |
def clear_gpu_memory():
    """
    Run garbage collection, then release cached CUDA memory (if CUDA is available).

    Garbage collection runs first so that tensors kept alive only by reference
    cycles are actually freed. ``torch.cuda.empty_cache()`` can then return those
    now-unused cached blocks to the driver. In the reverse order such blocks would
    stay cached until the next call.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def _get_feature_device_candidates() -> list[str]:
    """
    List the devices to attempt feature extraction on, most preferred first.

    CUDA leads the list when it is available, with CPU kept as a fallback for
    GPUs that lack kernels for some operation. Without CUDA, only CPU is returned.
    """
    candidates = ["cpu"]
    if torch.cuda.is_available():
        candidates.insert(0, "cuda")
    return candidates
def _should_retry_on_cpu(exc: RuntimeError, device: str) -> bool:
    """
    Decide whether a failed extraction attempt should be repeated on CPU.

    Only attempts made on exactly ``"cuda"`` qualify. They qualify only when the
    error text (compared case-insensitively) reports that no kernel image is
    available for the device.
    """
    if device == "cuda":
        needle = "no kernel image is available for execution on the device"
        return needle in str(exc).lower()
    return False
| # ===== Feature Extraction Functions ===== | |
def extract_dino_features(images: torch.Tensor, batch_size: int = DEFAULT_BATCH_SIZE) -> torch.Tensor:
    """
    Extract features using the DINO ViT-B/16 model.

    The model is loaded from ``torch.hub`` ('facebookresearch/dino:main',
    'dino_vitb16') on each candidate device, with CUDA tried first when available.
    If the CUDA attempt fails because no kernel image exists for the GPU,
    extraction is retried on CPU. After every attempt the model reference is
    dropped and ``clear_gpu_memory()`` is called.

    Args:
        images (torch.Tensor): Input images of shape (N, C, H, W), presumably
            normalized with ``dino_image_transform`` (TODO confirm with callers).
        batch_size (int): Number of images per forward pass; must be >= 1.

    Returns:
        torch.Tensor: DINO features of shape (N, L, D) on CPU. These are the last
        entry of the model's ``get_intermediate_layers()`` output, computed
        without autograd tracking.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        RuntimeError: If model loading or inference fails. CUDA kernel-image
            failures are retried on CPU first, and the last error is re-raised if
            every device fails.
    """
    # Validate before downloading/loading the model. Previously 0 raised
    # ZeroDivisionError and negatives silently produced no batches.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    num_images = images.shape[0]
    last_error: Optional[RuntimeError] = None
    for device in _get_feature_device_candidates():
        dino_model = None
        try:
            dino_model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
            dino_model = dino_model.eval().to(device)

            feature_batches = []
            # eval() only switches layer behaviour; it does not disable autograd.
            # Unless the parameters are frozen, each stored output would otherwise
            # keep its autograd graph (and activations) alive until we return.
            with torch.no_grad():
                for start_idx in range(0, num_images, batch_size):
                    # Slicing past the end is clamped, so the last batch may be smaller.
                    batch_images = images[start_idx:start_idx + batch_size].to(device)
                    batch_features = dino_model.get_intermediate_layers(batch_images)[-1]
                    feature_batches.append(batch_features.cpu())

            return torch.cat(feature_batches, dim=0)
        except RuntimeError as exc:
            last_error = exc
            if _should_retry_on_cpu(exc, device):
                logging.warning("DINO CUDA kernels are unavailable on this device; retrying feature extraction on CPU.")
                continue
            raise
        finally:
            # Runs on return, retry and raise alike: drop the model, then free memory.
            if dino_model is not None:
                del dino_model
            clear_gpu_memory()

    # Only reached when every candidate device failed with a retryable error.
    if last_error is not None:
        raise last_error
    raise RuntimeError("Failed to extract DINO features.")
def extract_clip_features(images: torch.Tensor, batch_size: int = DEFAULT_BATCH_SIZE, ipadapter_version: str = "sd15") -> torch.Tensor:
    """
    Extract features using the CLIP vision encoder bundled with an IP-Adapter.

    The IP-Adapter is loaded via ``load_ipadapter`` on each candidate device, with
    CUDA tried first when available. If the CUDA attempt fails because no kernel
    image exists for the GPU, extraction is retried on CPU. After every attempt
    the adapter reference is dropped and ``clear_gpu_memory()`` is called.

    Args:
        images (torch.Tensor): Input images of shape (N, C, H, W). They are passed
            to ``extract_clip_embedding_tensor`` with ``resize=False``, so they are
            presumably already at the encoder's input size (e.g. via
            ``clip_image_transform``). TODO confirm.
        batch_size (int): Number of images per forward pass; must be >= 1.
        ipadapter_version (str): IP-Adapter variant passed to ``load_ipadapter``.

    Returns:
        torch.Tensor: CLIP features of shape (N, L, D), concatenated on CPU and
        computed without autograd tracking.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        RuntimeError: If loading or inference fails. CUDA kernel-image failures
            are retried on CPU first, and the last error is re-raised if every
            device fails.
    """
    # Validate before loading the adapter. Previously 0 raised ZeroDivisionError
    # and negatives silently produced no batches.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    num_images = images.shape[0]
    last_error: Optional[RuntimeError] = None
    for device in _get_feature_device_candidates():
        ip_adapter_model = None
        try:
            # Load IP-Adapter model (contains CLIP encoder)
            ip_adapter_model = load_ipadapter(version=ipadapter_version, device=device)

            feature_batches = []
            # Pure inference: avoid building autograd graphs that would keep every
            # batch's activations alive through the stored outputs.
            with torch.no_grad():
                for start_idx in range(0, num_images, batch_size):
                    # Slicing past the end is clamped, so the last batch may be smaller.
                    batch_images = images[start_idx:start_idx + batch_size].to(device)
                    batch_features = extract_clip_embedding_tensor(
                        batch_images, ip_adapter_model, resize=False
                    )
                    feature_batches.append(batch_features.cpu())

            return torch.cat(feature_batches, dim=0)
        except RuntimeError as exc:
            last_error = exc
            if _should_retry_on_cpu(exc, device):
                logging.warning("CLIP CUDA kernels are unavailable on this device; retrying feature extraction on CPU.")
                continue
            raise
        finally:
            # Runs on return, retry and raise alike: drop the adapter, then free memory.
            if ip_adapter_model is not None:
                del ip_adapter_model
            clear_gpu_memory()

    # Only reached when every candidate device failed with a retryable error.
    if last_error is not None:
        raise last_error
    raise RuntimeError("Failed to extract CLIP features.")