# gap-clip / evaluation/utils/embeddings.py
# Author: Leacb4
# Commit: "Upload evaluation/utils/embeddings.py with huggingface_hub" (4807234, verified)
"""
Shared embedding extraction utilities for GAP-CLIP evaluation scripts.
Consolidates the batch embedding extraction logic that was duplicated across
sec51, sec52, sec533, and sec536 into two reusable functions:
- extract_clip_embeddings() — for any CLIP-based model (GAP-CLIP, Fashion-CLIP)
- extract_color_model_embeddings() — for the specialized 16D ColorCLIP model
"""
from __future__ import annotations
from typing import List, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _batch_tensors_to_pil(images: torch.Tensor) -> list:
    """Convert a batch of ImageNet-normalised tensors back to PIL images.

    This is the shared denormalization logic that was duplicated in every
    evaluator's image-embedding extraction method.

    Some images fall outside ``[0, 1]`` and are not ``uint8``. Those are
    assumed to be ImageNet-normalised: they are un-normalised
    (``x * std + mean``) and clamped to ``[0, 1]`` before conversion.
    ``uint8`` images, and images already inside ``[0, 1]``, are passed to
    ``ToPILImage`` unchanged.

    The ImageNet statistics are per-channel for 3 channels, so each image is
    assumed to be ``(3, H, W)``. Callers in this module expand inputs to 3
    channels before calling.

    Args:
        images: Batch tensor whose first dimension indexes images.

    Returns:
        A list of PIL images, one per entry in the batch.
    """
    to_pil = transforms.ToPILImage()
    # ImageNet channel statistics, built once per batch rather than per image.
    mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(3, 1, 1)

    pil_images = []
    for t in images:
        # uint8 pixels (0-255) are already a format ToPILImage understands.
        # "Denormalising" them would clamp nearly every value to 1 and yield
        # a blank white image.
        if t.dtype != torch.uint8 and (t.min() < 0 or t.max() > 1):
            t = torch.clamp(t * std + mean, 0, 1)
        pil_images.append(to_pil(t.cpu()))
    return pil_images
def _normalize_label(value: object, default: str = "unknown") -> str:
    """Convert a label-like value to a consistent, non-empty lowercase string.

    These values are treated as missing and mapped to *default*:

    - ``None``;
    - float/NumPy NaN;
    - empty or whitespace-only strings;
    - the textual forms ``"none"`` and ``"nan"``;
    - pandas' ``"<NA>"`` / ``"NaT"``.

    Every other value is stringified, stripped and lower-cased. The spelling
    ``"grey"`` is then replaced by ``"gray"``.

    Args:
        value: Raw label (string, number, NumPy/pandas scalar, ...).
        default: Value returned for missing or empty labels.

    Returns:
        The normalised label string.
    """
    if value is None:
        return default
    # Handle pandas/NumPy missing values without importing pandas here.
    # np.isnan raises for strings and other non-numeric objects, and bool()
    # raises for multi-element arrays or pd.NA. Any such failure just means
    # "not recognisably a float NaN", so it is deliberately ignored
    # (best-effort check).
    try:
        if bool(np.isnan(value)):  # type: ignore[arg-type]
            return default
    except Exception:
        pass
    label = str(value).strip().lower()
    # "<na>" and "nat" are str(pd.NA) and str(pd.NaT) after lower-casing.
    # They can slip past the np.isnan check above (e.g. bool(pd.NA) raises
    # TypeError).
    if not label or label in {"none", "nan", "<na>", "nat"}:
        return default
    return label.replace("grey", "gray")
# ---------------------------------------------------------------------------
# CLIP-based embedding extraction (GAP-CLIP or Fashion-CLIP)
# ---------------------------------------------------------------------------
def extract_clip_embeddings(
    model,
    processor,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str], List[str]]:
    """Extract L2-normalised embeddings from any CLIP-based model.

    Accepts 3-element batches ``(image, text, color)`` and batches with 4 or
    more elements ``(image, text, color, hierarchy, ...)``. Elements beyond
    the 4th are ignored.

    Always returns a triple ``(embeddings, colors, hierarchies)``. When the
    batch has no hierarchy column, the hierarchy list is filled with
    ``"unknown"``. Labels are normalised with ``_normalize_label``.

    Args:
        model: A ``CLIPModel`` (GAP-CLIP, Fashion-CLIP, etc.).
        processor: Matching ``CLIPProcessor``.
        dataloader: PyTorch DataLoader yielding tuples of 3 or more elements.
        device: Target torch device for the processor outputs.
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Maximum number of samples to return. The last batch is
            trimmed so that at most this many rows are produced.
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors, hierarchies)``:

        - *embeddings* is an ``(N, D)`` float numpy array, or an empty
          ``(0, 0)`` array when no samples were collected.
        - *colors* and *hierarchies* are lists of ``N`` strings.

    Raises:
        ValueError: If ``embedding_type`` is neither ``"text"`` nor ``"image"``.
    """
    if embedding_type not in ("text", "image"):
        raise ValueError(
            f"embedding_type must be 'text' or 'image', got {embedding_type!r}"
        )
    if desc is None:
        desc = f"Extracting {embedding_type} embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    all_hierarchies: list[str] = []
    sample_count = 0

    if max_samples <= 0:
        return np.empty((0, 0), dtype=np.float32), all_colors, all_hierarchies

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            images, texts, colors = batch[0], batch[1], batch[2]
            if len(batch) >= 4:
                hierarchies = batch[3]
            else:
                hierarchies = ["unknown"] * len(colors)

            # Trim the final batch so no more than max_samples rows are returned.
            remaining = max_samples - sample_count
            if len(images) > remaining:
                images = images[:remaining]
                texts = texts[:remaining]
                colors = colors[:remaining]
                hierarchies = hierarchies[:remaining]

            if embedding_type == "image":
                # Broadcast 1-channel inputs to 3 channels. PIL conversion
                # happens on the CPU, so the raw pixels are never copied to
                # ``device``; only the processor output is.
                pil_images = _batch_tensors_to_pil(images.expand(-1, 3, -1, -1))
                inputs = processor(images=pil_images, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_image_features(**inputs)
            else:
                inputs = processor(
                    text=list(texts),
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77,
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_text_features(**inputs)

            emb = F.normalize(emb, dim=-1)
            # .float() lets half/bfloat16 models convert to NumPy, which has
            # no bfloat16 dtype. It is a no-op for float32.
            all_embeddings.append(emb.float().cpu().numpy())
            all_colors.extend(_normalize_label(c) for c in colors)
            all_hierarchies.extend(_normalize_label(h) for h in hierarchies)
            sample_count += len(images)

            del images, emb
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Stop here rather than at the top of the next iteration, so no
            # extra batch is loaded just to be discarded.
            if sample_count >= max_samples:
                break

    if not all_embeddings:
        return np.empty((0, 0), dtype=np.float32), all_colors, all_hierarchies
    return np.vstack(all_embeddings), all_colors, all_hierarchies
# ---------------------------------------------------------------------------
# Specialized ColorCLIP embedding extraction
# ---------------------------------------------------------------------------
def extract_color_model_embeddings(
    color_model,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str]]:
    """Extract L2-normalised embeddings from the 16D ColorCLIP model.

    Color labels are normalised with ``_normalize_label``, the same helper
    used by ``extract_clip_embeddings``. As a result, missing values map to
    ``"unknown"`` consistently across both extractors.

    Args:
        color_model: A ``ColorCLIP`` instance.
        dataloader: DataLoader yielding at least ``(image, text, color, ...)``.
        device: Target torch device for image inputs.
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Maximum number of samples to return. The last batch is
            trimmed so that at most this many rows are produced.
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors)``:

        - *embeddings* is an ``(N, 16)`` numpy array, or an empty ``(0, 0)``
          array when no samples were collected.
        - *colors* is a list of ``N`` strings.

    Raises:
        ValueError: If ``embedding_type`` is neither ``"text"`` nor ``"image"``.
    """
    if embedding_type not in ("text", "image"):
        raise ValueError(
            f"embedding_type must be 'text' or 'image', got {embedding_type!r}"
        )
    if desc is None:
        desc = f"Extracting {embedding_type} color-model embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    sample_count = 0

    if max_samples <= 0:
        return np.empty((0, 0), dtype=np.float32), all_colors

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            images, texts, colors = batch[0], batch[1], batch[2]

            # Trim the final batch so no more than max_samples rows are returned.
            remaining = max_samples - sample_count
            if len(images) > remaining:
                images = images[:remaining]
                texts = texts[:remaining]
                colors = colors[:remaining]

            if embedding_type == "text":
                emb = color_model.get_text_embeddings(list(texts))
            else:
                # Pixels are only needed on the device for the image path.
                # 1-channel inputs are broadcast to 3 channels.
                emb = color_model.get_image_embeddings(
                    images.to(device).expand(-1, 3, -1, -1)
                )

            emb = F.normalize(emb, dim=-1)
            # .float() lets half/bfloat16 outputs convert to NumPy. It is a
            # no-op for float32.
            all_embeddings.append(emb.float().cpu().numpy())
            all_colors.extend(_normalize_label(c) for c in colors)
            sample_count += len(images)
            # Stop here rather than at the top of the next iteration, so no
            # extra batch is loaded just to be discarded.
            if sample_count >= max_samples:
                break

    if not all_embeddings:
        return np.empty((0, 0), dtype=np.float32), all_colors
    return np.vstack(all_embeddings), all_colors