VibeSpace / extract_features.py
huzey's picture
Add CPU fallback for HF custom feature stages
22afff9
"""
Feature Extraction Module
This module provides utilities for extracting features from images using various
pre-trained models including DINO, DINOv3, and CLIP. It handles model loading,
batch processing, and memory management for efficient feature extraction.
"""
import gc
import logging
from typing import Tuple, Optional
import torch
import torch.nn as nn
from einops import rearrange
from torchvision import transforms
from ipadapter_model import extract_clip_embedding_tensor
from ipadapter_model import load_ipadapter
# Default hyperparameters
DEFAULT_BATCH_SIZE = 32

# ===== Image Transforms =====
# ImageNet per-channel statistics shared by every transform below.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _build_normalized_transform(side: int) -> transforms.Compose:
    """Build a pipeline: square resize to ``side`` x ``side``, ToTensor, ImageNet Normalize."""
    return transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD)),
    ])


# High-resolution (512 x 512) input for DINO models, for more detailed features.
dino_image_transform = _build_normalized_transform(256 * 2)

# Standard ImageNet resolution (224 x 224) for CLIP models.
# NOTE(review): this uses ImageNet statistics, not CLIP's own normalization
# constants — confirm extract_clip_embedding_tensor expects this.
clip_image_transform = _build_normalized_transform(224)

# Undo the normalization above and convert back to a PIL image.
# Step 1: (x - 0) / (1 / std) == x * std.  Step 2: (x - (-mean)) / 1 == x + mean.
image_inverse_transform = transforms.Compose([
    transforms.Normalize(mean=[0.0] * 3, std=[1 / s for s in _IMAGENET_STD]),
    transforms.Normalize(mean=[-m for m in _IMAGENET_MEAN], std=[1.0] * 3),
    transforms.ToPILImage(),
])
# ===== Memory Management =====
def clear_gpu_memory():
    """
    Free unreferenced memory: run Python garbage collection, then (when CUDA
    is available) release cached, unoccupied CUDA allocator blocks.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are actually freed before the CUDA cache is emptied.
    In the opposite order their blocks would remain held by the cache.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def _get_feature_device_candidates() -> list[str]:
"""Prefer CUDA when available, but allow CPU fallback for unsupported kernels."""
return ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]
def _should_retry_on_cpu(exc: RuntimeError, device: str) -> bool:
if device != "cuda":
return False
error_message = str(exc).lower()
return "no kernel image is available for execution on the device" in error_message
# ===== Feature Extraction Functions =====
@torch.no_grad()
def extract_dino_features(images: torch.Tensor, batch_size: int = DEFAULT_BATCH_SIZE) -> torch.Tensor:
    """
    Extract features using the DINO ViT-B/16 model.

    The model is loaded with ``torch.hub.load('facebookresearch/dino:main',
    'dino_vitb16')``. Devices from ``_get_feature_device_candidates()`` are
    tried in order (CUDA first when available). If the CUDA attempt fails
    because no kernel image is available for the GPU, the whole extraction is
    retried on CPU with a freshly loaded model.

    Args:
        images (torch.Tensor): Input images of shape (N, C, H, W), with N >= 1.
        batch_size (int): Number of images per forward pass; must be >= 1.

    Returns:
        torch.Tensor: DINO features of shape (N, L, D), located on the CPU.
            Each batch contributes the last entry of the model's
            ``get_intermediate_layers`` output.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``images`` contains no images.
        RuntimeError: If extraction fails for any reason other than missing
            CUDA kernels, or if every candidate device fails.
    """
    # Validate before the (expensive) hub model load. Previously batch_size=0
    # raised ZeroDivisionError, a negative batch_size silently produced no
    # batches, and empty input only failed later inside torch.cat([]).
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if images.shape[0] == 0:
        raise ValueError("images must contain at least one image")

    last_error: Optional[RuntimeError] = None
    for device in _get_feature_device_candidates():
        dino_model = None
        try:
            dino_model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
            dino_model = dino_model.eval().to(device)
            # Restart from an empty list on every device attempt, so partial
            # results from a failed CUDA run are discarded.
            feature_batches = []
            # Tensor.split yields consecutive chunks of `batch_size` rows; the last may be smaller.
            for batch_images in images.split(batch_size):
                batch_features = dino_model.get_intermediate_layers(batch_images.to(device))[-1]
                # Copy each result to CPU right away so only one batch of outputs lives on the device.
                feature_batches.append(batch_features.cpu())
            return torch.cat(feature_batches, dim=0)
        except RuntimeError as exc:
            last_error = exc
            if _should_retry_on_cpu(exc, device):
                logging.warning("DINO CUDA kernels are unavailable on this device; retrying feature extraction on CPU.")
                continue
            raise
        finally:
            # Drop the model reference and release cached GPU memory, whether the attempt succeeded or failed.
            if dino_model is not None:
                del dino_model
            clear_gpu_memory()
    # Defensive: only reachable if the final candidate device also requested a retry.
    if last_error is not None:
        raise last_error
    raise RuntimeError("Failed to extract DINO features.")
@torch.no_grad()
def extract_clip_features(images: torch.Tensor, batch_size: int = DEFAULT_BATCH_SIZE, ipadapter_version: str = "sd15") -> torch.Tensor:
    """
    Extract features using the CLIP vision encoder bundled with an IP-Adapter model.

    The IP-Adapter model is loaded with ``load_ipadapter``, and each batch is
    encoded with ``extract_clip_embedding_tensor(..., resize=False)``.
    Devices from ``_get_feature_device_candidates()`` are tried in order
    (CUDA first when available). If the CUDA attempt fails because no kernel
    image is available for the GPU, the whole extraction is retried on CPU
    with a freshly loaded model.

    Args:
        images (torch.Tensor): Input images of shape (N, C, H, W), with N >= 1.
            Because ``resize=False`` is passed, they presumably must already be
            at the encoder's input resolution (e.g. via ``clip_image_transform``);
            TODO confirm against ``extract_clip_embedding_tensor``.
        batch_size (int): Number of images per forward pass; must be >= 1.
        ipadapter_version (str): Forwarded as ``load_ipadapter(version=...)``.
            Presumably selects the IP-Adapter checkpoint family; verify in
            ``load_ipadapter``.

    Returns:
        torch.Tensor: CLIP features of shape (N, L, D), located on the CPU.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``images`` contains no images.
        RuntimeError: If extraction fails for any reason other than missing
            CUDA kernels, or if every candidate device fails.
    """
    # Validate before loading the IP-Adapter. Previously batch_size=0 raised
    # ZeroDivisionError, a negative batch_size silently produced no batches,
    # and empty input only failed later inside torch.cat([]).
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if images.shape[0] == 0:
        raise ValueError("images must contain at least one image")

    last_error: Optional[RuntimeError] = None
    for device in _get_feature_device_candidates():
        ip_adapter_model = None
        try:
            # Load IP-Adapter model (contains CLIP encoder)
            ip_adapter_model = load_ipadapter(version=ipadapter_version, device=device)
            # Restart from an empty list on every device attempt, so partial
            # results from a failed CUDA run are discarded.
            feature_batches = []
            # Tensor.split yields consecutive chunks of `batch_size` rows; the last may be smaller.
            for batch_images in images.split(batch_size):
                batch_features = extract_clip_embedding_tensor(
                    batch_images.to(device), ip_adapter_model, resize=False
                )
                # Copy each result to CPU right away so only one batch of outputs lives on the device.
                feature_batches.append(batch_features.cpu())
            return torch.cat(feature_batches, dim=0)
        except RuntimeError as exc:
            last_error = exc
            if _should_retry_on_cpu(exc, device):
                logging.warning("CLIP CUDA kernels are unavailable on this device; retrying feature extraction on CPU.")
                continue
            raise
        finally:
            # Drop the model reference and release cached GPU memory, whether the attempt succeeded or failed.
            if ip_adapter_model is not None:
                del ip_adapter_model
            clear_gpu_memory()
    # Defensive: only reachable if the final candidate device also requested a retry.
    if last_error is not None:
        raise last_error
    raise RuntimeError("Failed to extract CLIP features.")