| """ |
| Shared embedding extraction utilities for GAP-CLIP evaluation scripts. |
| |
| Consolidates the batch embedding extraction logic that was duplicated across |
| sec51, sec52, sec533, and sec536 into two reusable functions: |
| |
| - extract_clip_embeddings() — for any CLIP-based model (GAP-CLIP, Fashion-CLIP) |
| - extract_color_model_embeddings() — for the specialized 16D ColorCLIP model |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import List, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
| from tqdm import tqdm |
|
|
|
|
| |
| |
| |
|
|
def _batch_tensors_to_pil(images: torch.Tensor) -> list:
    """Convert a batch of ImageNet-normalised tensors back to PIL images.

    This is the shared denormalization logic that was duplicated in every
    evaluator's image-embedding extraction method.

    Args:
        images: ``(B, C, H, W)`` float tensor. Any image whose values fall
            outside ``[0, 1]`` is assumed to be ImageNet-normalised and is
            denormalised (and clamped) before conversion.

    Returns:
        A list of ``B`` PIL images.
    """
    # Hoist loop invariants: the ImageNet statistics and the converter are
    # identical for every image in the batch.
    mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(3, 1, 1)
    to_pil = transforms.ToPILImage()

    pil_images = []
    for t in images:
        # Heuristic: values outside [0, 1] mean the tensor is still normalised.
        if t.min() < 0 or t.max() > 1:
            t = torch.clamp(t * std + mean, 0, 1)
        pil_images.append(to_pil(t.cpu()))
    return pil_images
|
|
|
|
| def _normalize_label(value: object, default: str = "unknown") -> str: |
| """Convert label-like values to consistent non-empty strings.""" |
| if value is None: |
| return default |
|
|
| |
| try: |
| if bool(np.isnan(value)): |
| return default |
| except Exception: |
| pass |
|
|
| label = str(value).strip().lower() |
| if not label or label in {"none", "nan"}: |
| return default |
| return label.replace("grey", "gray") |
|
|
|
|
| |
| |
| |
|
|
def extract_clip_embeddings(
    model,
    processor,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str], List[str]]:
    """Extract L2-normalised embeddings from any CLIP-based model.

    Works with both 3-element batches ``(image, text, color)`` and 4-element
    batches ``(image, text, color, hierarchy)``. Always returns three lists
    (embeddings, colors, hierarchies); when the batch has no hierarchy column
    the third list is filled with ``"unknown"``.

    Args:
        model: A ``CLIPModel`` (GAP-CLIP, Fashion-CLIP, etc.).
        processor: Matching ``CLIPProcessor``.
        dataloader: PyTorch DataLoader yielding 3- or 4-element tuples.
        device: Target torch device.
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Stop after collecting this many samples. The cut-off is
            checked at batch granularity, so the result may exceed it by up
            to one batch.
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors, hierarchies)`` where *embeddings* is an
        ``(N, D)`` numpy array and the other two are lists of strings.
        An empty/exhausted dataloader (or ``max_samples == 0``) yields a
        ``(0, 0)`` array and empty lists instead of raising.
    """
    if desc is None:
        desc = f"Extracting {embedding_type} embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    all_hierarchies: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            if sample_count >= max_samples:
                break

            # 4-element batches carry an explicit hierarchy column.
            if len(batch) == 4:
                images, texts, colors, hierarchies = batch
            else:
                images, texts, colors = batch
                hierarchies = ["unknown"] * len(colors)

            if embedding_type == "image":
                # Broadcast single-channel inputs to RGB before converting
                # back to PIL for the HF processor.
                images = images.to(device).expand(-1, 3, -1, -1)
                pil_images = _batch_tensors_to_pil(images)
                inputs = processor(images=pil_images, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_image_features(**inputs)
            else:
                # Text path never touches the pixels — skip the device
                # transfer (the tensor is only used for its batch length).
                inputs = processor(
                    text=list(texts),
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77,  # CLIP's fixed text context length
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_text_features(**inputs)

            emb = F.normalize(emb, dim=-1)

            all_embeddings.append(emb.cpu().numpy())
            all_colors.extend(_normalize_label(c) for c in colors)
            all_hierarchies.extend(_normalize_label(h) for h in hierarchies)
            sample_count += len(images)

            # Release the batch's GPU memory before the next iteration.
            del images, emb
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # np.vstack raises ValueError on an empty list; return an empty result
    # instead so callers can handle "no samples" gracefully.
    if not all_embeddings:
        return np.empty((0, 0), dtype=np.float32), all_colors, all_hierarchies
    return np.vstack(all_embeddings), all_colors, all_hierarchies
|
|
|
|
| |
| |
| |
|
|
def extract_color_model_embeddings(
    color_model,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str]]:
    """Extract L2-normalised embeddings from the 16D ColorCLIP model.

    Args:
        color_model: A ``ColorCLIP`` instance.
        dataloader: DataLoader yielding at least ``(image, text, color, ...)``.
        device: Target torch device.
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Stop after collecting this many samples (checked at
            batch granularity, so the result may exceed it by one batch).
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors)`` — embeddings is ``(N, 16)`` numpy array.
        An empty/exhausted dataloader yields a ``(0, 0)`` array and an
        empty list instead of raising.
    """
    if desc is None:
        desc = f"Extracting {embedding_type} color-model embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            if sample_count >= max_samples:
                break

            # Index instead of unpacking to tolerate extra trailing batch
            # columns (e.g. a hierarchy field).
            images, texts, colors = batch[0], batch[1], batch[2]
            images = images.to(device).expand(-1, 3, -1, -1)

            if embedding_type == "text":
                emb = color_model.get_text_embeddings(list(texts))
            else:
                emb = color_model.get_image_embeddings(images)
            emb = F.normalize(emb, dim=-1)

            all_embeddings.append(emb.cpu().numpy())
            # Use the shared helper so labels match extract_clip_embeddings
            # (lowercased, stripped, grey->gray, blanks/NaN -> "unknown").
            all_colors.extend(_normalize_label(c) for c in colors)
            sample_count += len(images)

    # np.vstack raises ValueError on an empty list; return an empty result
    # instead so callers can handle "no samples" gracefully.
    if not all_embeddings:
        return np.empty((0, 0), dtype=np.float32), all_colors
    return np.vstack(all_embeddings), all_colors
|
|