| |
| """ |
| Section 5.3.6 — Embedding Structure Evaluation |
| =============================================== |
| |
| Verifies that the GAP-CLIP embedding subspaces encode the attributes they are |
| designed for, and tests zero-shot vision-language alignment. |
| |
| Test A — Different colors, same hierarchy: |
| The 64D hierarchy subspace should be MORE similar between two items that |
| share a category but differ in color, compared to the 16D color subspace. |
| Expected result: 1000/1000 pass. |
| Example: |
| In Test A, the code computes for each pair: |
| - sim_color = cosine between the color slice (emb[:16]) |
| - sim_hier = cosine between the hierarchy slice (emb[16:80]) |
| - sim_full512 = cosine between the full 512-d embedding (emb) |
| The test checks: |
| - pair_ok = (sim_hier > sim_color) and (sim_hier > sim_full512) |
| Test B — Same color, different hierarchies: |
| The 16D color subspace should be MORE similar than both the full 512D |
| embedding and the 64D hierarchy subspace for items sharing a color but |
| differing in category. |
| Expected result: 1000/1000 pass. |
| |
| Test C — Subspace Decomposition Consistency: |
| Encode a full description (e.g. "red dress in cotton"), a standalone color |
| ("red"), and a standalone hierarchy ("dress"). Verify that: |
| - The color subspace (first 16D) of the full embedding is more similar |
| to the color-only embedding than to the hierarchy-only embedding. |
| - The hierarchy subspace (dims 16-80) of the full embedding is more |
| similar to the hierarchy-only embedding than to the color-only embedding. |
| Expected result: 1000/1000 pass. |
| |
| Test D — Zero-shot image-to-text classification: |
| Each image is used as a query; the highest-scoring text label (cosine in |
| shared latent space) is the predicted class. Accuracy is computed across |
| three datasets (Fashion-MNIST, KAGL Marqo, Internal). |
| |
| Paper reference: Section 5.3.6 and Table 4. |
| |
| Run directly: |
| python sec536_embedding_structure.py --tests AB # only tests A+B |
| python sec536_embedding_structure.py --tests ABCD # all tests |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
| from dataclasses import dataclass |
| from pathlib import Path |
| import random |
| import sys |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| from typing import Dict, List, Optional, Sequence, Tuple |
|
|
| import numpy as np |
| import pandas as pd |
| import requests |
| from sklearn.metrics import f1_score |
| import torch |
| import torch.nn.functional as F |
| from io import BytesIO |
| from PIL import Image |
| from torchvision import transforms |
| from torchvision import datasets |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import CLIPModel as CLIPModelTransformers |
| from transformers import CLIPProcessor |
|
|
| from training.hierarchy_model import HierarchyExtractor |
|
|
# Optional project configuration. Any failure while importing config.py is
# tolerated: project_config becomes None and the built-in fallbacks below apply.
try:
    import config as project_config
except Exception:
    project_config = None


# getattr() on None returns the supplied fallback, so these assignments work
# whether or not the import above succeeded.
DEFAULT_COLOR_EMB_DIM = getattr(project_config, "color_emb_dim", 16)
DEFAULT_HIERARCHY_EMB_DIM = getattr(project_config, "hierarchy_emb_dim", 64)
DEFAULT_MAIN_EMB_DIM = getattr(project_config, "main_emb_dim", 512)
DEFAULT_MAIN_MODEL_PATH = getattr(project_config, "main_model_path", "models/gap_clip.pth")
DEFAULT_DEVICE = getattr(project_config, "device", torch.device("cpu"))


# Extractor instantiated once at import time over the hierarchy vocabulary below.
# NOTE(review): it is not referenced in the visible part of this file;
# presumably label-normalisation code further down uses it -- confirm.
_HIERARCHY_EXTRACTOR = HierarchyExtractor([
    "accessories", "bodysuits", "bras", "coat", "dress", "jacket",
    "legging", "pant", "polo", "shirt", "shoes", "short", "skirt",
    "socks", "sweater", "swimwear", "top", "underwear",
], verbose=False)
|
|
|
|
@dataclass
class RuntimeConfig:
    """Runtime settings for the evaluation; each field defaults to the matching DEFAULT_* constant."""

    # Width of the color slice; the tests read it as emb[:color_emb_dim].
    color_emb_dim: int = DEFAULT_COLOR_EMB_DIM
    # Width of the hierarchy slice, read immediately after the color slice.
    hierarchy_emb_dim: int = DEFAULT_HIERARCHY_EMB_DIM
    # Size of the full embedding (not read by the visible test functions).
    main_emb_dim: int = DEFAULT_MAIN_EMB_DIM
    # Filesystem path of the GAP-CLIP checkpoint.
    main_model_path: str = DEFAULT_MAIN_MODEL_PATH
    # Device on which texts and images are encoded.
    device: torch.device = DEFAULT_DEVICE
|
|
# NOTE(review): neither constant is read in the visible part of the file;
# presumably they are the CLI defaults for the num_examples / num_printed
# arguments of the run_test_* functions -- confirm. The module docstring
# quotes "1000/1000" while this default is 10000.
DEFAULT_NUM_EXAMPLES = 10000
DEFAULT_NUM_PRINTED = 3


# Color vocabulary sampled by tests A, B and C.
COLORS = [
    "yellow", "blue", "red", "green", "black", "white", "pink", "purple", "brown", "orange",
]
# Category vocabulary sampled by tests A, B and C.
HIERARCHIES = [
    "dress", "shirt", "pants", "skirt", "jacket", "coat", "jeans", "sweater", "shorts", "top",
]




# Prompt templates; build_text_query() picks one at random and fills the
# {color} and {hierarchy} placeholders.
LONG_TEXT_TEMPLATES = [
    "{color} {hierarchy}",
    "{color} {hierarchy} with buttons",
    "{color} {hierarchy} in cotton",
    "casual {color} {hierarchy} for women",
    "elegant {color} {hierarchy} with pockets",
]
|
|
|
|
def build_text_query(color: str, hierarchy: str) -> str:
    """Fill one randomly chosen entry of LONG_TEXT_TEMPLATES with *color* and *hierarchy*.

    Draws exactly one value from the global ``random`` generator per call.
    """
    return random.choice(LONG_TEXT_TEMPLATES).format(color=color, hierarchy=hierarchy)
|
|
|
|
def resolve_runtime_config() -> RuntimeConfig:
    """Build a RuntimeConfig, preferring values from a local ``config`` module.

    When ``config`` imports cleanly, each attribute it defines overrides the
    dataclass default. When the import (or any attribute read) raises, the
    device is auto-detected instead: CUDA first, then MPS, then CPU.
    """
    cfg = RuntimeConfig()
    overridable = (
        "color_emb_dim",
        "hierarchy_emb_dim",
        "main_emb_dim",
        "main_model_path",
        "device",
    )
    try:
        import config

        # A missing attribute keeps whatever value cfg already holds.
        for field_name in overridable:
            setattr(cfg, field_name, getattr(config, field_name, getattr(cfg, field_name)))
    except Exception:
        if torch.cuda.is_available():
            backend = "cuda"
        elif torch.backends.mps.is_available():
            backend = "mps"
        else:
            backend = "cpu"
        cfg.device = torch.device(backend)

    return cfg
|
|
|
|
def load_main_model(device: torch.device, main_model_path: str) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load the GAP-CLIP weights from a local checkpoint onto *device*.

    The base CLIP architecture and processor are instantiated from the LAION
    ViT-B/32 identifier, then the checkpoint's state dict is applied with
    ``strict=False`` (mismatched keys are ignored rather than raising).

    Raises:
        FileNotFoundError: if *main_model_path* does not exist.
    """
    checkpoint_path = Path(main_model_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Main model checkpoint not found: {main_model_path}")

    clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    model = CLIPModelTransformers.from_pretrained(clip_name)
    checkpoint = torch.load(str(checkpoint_path), map_location=device)

    # Checkpoints may be a bare state dict or a dict wrapping it under "model_state_dict".
    is_wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    state_dict = checkpoint["model_state_dict"] if is_wrapped else checkpoint
    model.load_state_dict(state_dict, strict=False)

    model = model.to(device)
    model.eval()
    processor = CLIPProcessor.from_pretrained(clip_name)
    return model, processor
|
|
|
|
def encode_text(model, processor, text_queries, device):
    """Return the model's text features for one string or a list of strings.

    The features are returned exactly as ``model.get_text_features`` produces
    them (no normalization here), computed with gradients disabled.
    """
    batch = [text_queries] if isinstance(text_queries, str) else text_queries
    encoded = processor(text=batch, return_tensors="pt", padding=True, truncation=True)

    # Move every tensor produced by the processor to the target device.
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(device)

    with torch.no_grad():
        return model.get_text_features(**on_device)
|
|
|
|
def encode_image(model, processor, images, device):
    """Return the model's image features for one image or a list of images.

    Anything that is not a ``list`` is treated as a single image and wrapped.
    The features are returned exactly as ``model.get_image_features`` produces
    them (no normalization here), computed with gradients disabled.
    """
    batch = images if isinstance(images, list) else [images]
    encoded = processor(images=batch, return_tensors="pt")

    # Move every tensor produced by the processor to the target device.
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(device)

    with torch.no_grad():
        return model.get_image_features(**on_device)
|
|
|
|
def get_text_embedding(
    model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, text: str) -> torch.Tensor:
    """Encode one text and return its L2-normalized embedding with the batch axis removed.

    The result is a 1-D vector for a single string (512-d per the module
    defaults -- TODO confirm against the loaded checkpoint).
    """
    raw_features = encode_text(model, processor, text, device)
    unit_features = F.normalize(raw_features, dim=-1)
    return unit_features.squeeze(0)
|
|
|
|
def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the cosine similarity of two 1-D tensors as a Python float."""
    lhs = a.unsqueeze(0)
    rhs = b.unsqueeze(0)
    similarity = F.cosine_similarity(lhs, rhs, dim=1)
    return similarity.item()
|
|
|
|
def delta_percent(reference: float, value: float) -> float:
    """Return the change from *reference* to *value* as a percentage of |reference|.

    The divisor is floored at 1e-8 so a zero reference does not divide by zero.
    """
    floor = 1e-8
    magnitude = abs(reference)
    # Same selection rule as max(magnitude, floor), NaN handling included.
    divisor = floor if floor > magnitude else magnitude
    difference = value - reference
    return (difference / divisor) * 100.0
|
|
|
|
def format_bool(ok: bool) -> str:
    """Map a truthy value to "PASS" and a falsy one to "FAIL"."""
    if ok:
        return "PASS"
    return "FAIL"
|
|
|
|
def print_table(title: str, headers: List[str], rows: List[List[str]]) -> None:
    """Print a titled, column-aligned text table to stdout.

    Each column is as wide as its longest cell (header included); cells are
    left-justified and joined with " | ".
    """
    banner = "=" * 120
    print("\n" + banner)
    print(title)
    print(banner)

    n_cols = len(headers)
    every_line = [headers] + rows
    widths = []
    for col in range(n_cols):
        widths.append(max(len(str(line[col])) for line in every_line))

    def render(cells: List[str]) -> str:
        padded = []
        for idx, cell in enumerate(cells):
            padded.append(str(cell).ljust(widths[idx]))
        return " | ".join(padded)

    print(render(headers))
    # The rule spans all columns plus the three-character separators between them.
    print("-" * (sum(widths) + 3 * (n_cols - 1)))
    for line in rows:
        print(render(line))
|
|
|
|
def run_test_a(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    cfg: RuntimeConfig,
    num_examples: int,
    num_printed: int,
    test_name: str = "Test A",
) -> Dict[str, object]:
    """
    A: different colors + same hierarchy.

    For each sampled pair of prompts that share a hierarchy but differ in
    color, the pair passes when the hierarchy-slice cosine is higher than both
    the color-slice cosine and the full-embedding cosine. A companion pair with
    a different hierarchy measures how often the full embedding ranks the
    same-hierarchy pair above the different-hierarchy one.

    Args:
        model, processor: CLIP model and processor passed to get_text_embedding.
        cfg: supplies the device and the color / hierarchy slice widths.
        num_examples: number of random pairs to evaluate (must be > 0).
        num_printed: number of example rows shown in the printed table.
        test_name: label used in the printed output.

    Returns:
        Dict with "overall" (bool: every pair passed) and float rates/averages
        under "accuracy_full512", "pass_rate", "hier_gt_full_rate",
        "hier_gt_color_rate", "avg_delta_color_vs_full", "avg_delta_hier_vs_full".

    Raises:
        ValueError: if num_examples is not positive (the rates would divide by zero).
    """
    if num_examples <= 0:
        raise ValueError(f"{test_name}: num_examples must be a positive integer, got {num_examples}")

    positive_pairs: List[Tuple[str, str]] = []
    negative_pairs: List[Tuple[str, str]] = []
    for _ in range(num_examples):
        hierarchy = random.choice(HIERARCHIES)
        c1, c2 = random.sample(COLORS, 2)
        negative_hierarchy = random.choice([h for h in HIERARCHIES if h != hierarchy])
        positive_pairs.append((build_text_query(c1, hierarchy), build_text_query(c2, hierarchy)))
        # The left element is unused later, but it is still generated so the
        # sequence of random draws is unchanged.
        negative_pairs.append((build_text_query(c1, hierarchy), build_text_query(c2, negative_hierarchy)))

    # The prompt space is small (colors x hierarchies x templates), so the same
    # text recurs many times; encode each distinct text only once.
    embedding_cache: Dict[str, torch.Tensor] = {}

    def embed(text: str) -> torch.Tensor:
        if text not in embedding_cache:
            embedding_cache[text] = get_text_embedding(model, processor, cfg.device, text)
        return embedding_cache[text]

    color_end = cfg.color_emb_dim
    hier_end = cfg.color_emb_dim + cfg.hierarchy_emb_dim

    rows: List[List[str]] = []
    pair_outcomes: List[bool] = []
    full512_outcomes: List[bool] = []
    hier_gt_full_outcomes: List[bool] = []
    hier_gt_color_outcomes: List[bool] = []
    delta_color_vs_full_values: List[float] = []
    delta_hier_vs_full_values: List[float] = []

    for (left, right), (_, negative_right) in zip(positive_pairs, negative_pairs):
        emb_left = embed(left)
        emb_right = embed(right)
        emb_negative_right = embed(negative_right)

        sim_color = cosine(emb_left[:color_end], emb_right[:color_end])
        sim_hier = cosine(emb_left[color_end:hier_end], emb_right[color_end:hier_end])
        sim_full512 = cosine(emb_left, emb_right)
        sim_full512_negative = cosine(emb_left, emb_negative_right)

        delta_color_vs_full_pct = delta_percent(sim_full512, sim_color)
        delta_hier_vs_full_pct = delta_percent(sim_full512, sim_hier)
        delta_color_vs_full_values.append(delta_color_vs_full_pct)
        delta_hier_vs_full_values.append(delta_hier_vs_full_pct)

        hierarchy_higher_than_full = sim_hier > sim_full512
        hierarchy_higher_than_color = sim_hier > sim_color
        pair_ok = hierarchy_higher_than_full and hierarchy_higher_than_color
        pair_outcomes.append(pair_ok)
        hier_gt_full_outcomes.append(hierarchy_higher_than_full)
        hier_gt_color_outcomes.append(hierarchy_higher_than_color)
        full512_outcomes.append(sim_full512 > sim_full512_negative)

        rows.append(
            [
                f"{left} vs {right}",
                f"{sim_color:.4f}",
                f"{sim_hier:.4f}",
                f"{sim_full512:.4f}",
                f"{delta_color_vs_full_pct:+.2f}%",
                f"{delta_hier_vs_full_pct:+.2f}%",
                format_bool(pair_ok),
            ]
        )

    print_table(
        f"{test_name}: Different colors, same hierarchy (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
        [
            "Pair",
            "CosSim first16(color)",
            "CosSim hier64",
            "CosSim full512",
            "Delta first16 vs full512 (%)",
            "Delta hier64 vs full512 (%)",
            "Result",
        ],
        rows[:num_printed],
    )

    total = len(pair_outcomes)
    overall = all(pair_outcomes)
    pass_rate = sum(pair_outcomes) / total
    full512_accuracy = sum(full512_outcomes) / total
    hier_gt_full_rate = sum(hier_gt_full_outcomes) / total
    hier_gt_color_rate = sum(hier_gt_color_outcomes) / total
    avg_delta_color_vs_full = sum(delta_color_vs_full_values) / total
    avg_delta_hier_vs_full = sum(delta_hier_vs_full_values) / total
    print(f"{test_name} aggregate: {sum(pair_outcomes)}/{total} passed ({pass_rate:.2%})")
    print(f" sub-condition hier > full512: {sum(hier_gt_full_outcomes)}/{total} ({hier_gt_full_rate:.2%})")
    print(f" sub-condition hier > color: {sum(hier_gt_color_outcomes)}/{total} ({hier_gt_color_rate:.2%})")
    print(
        f"{test_name} full512 pair-discrimination accuracy "
        f"(same-hierarchy > different-hierarchy): {sum(full512_outcomes)}/{total} "
        f"({full512_accuracy:.2%})"
    )
    print(
        f"{test_name} avg deltas: "
        f"first16 vs full512 = {avg_delta_color_vs_full:+.2f}%, "
        f"hier64 vs full512 = {avg_delta_hier_vs_full:+.2f}%"
    )
    return {
        "overall": overall,
        "accuracy_full512": full512_accuracy,
        "pass_rate": pass_rate,
        "hier_gt_full_rate": hier_gt_full_rate,
        "hier_gt_color_rate": hier_gt_color_rate,
        "avg_delta_color_vs_full": avg_delta_color_vs_full,
        "avg_delta_hier_vs_full": avg_delta_hier_vs_full,
    }
|
|
|
|
def run_test_b(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    cfg: RuntimeConfig,
    num_examples: int,
    num_printed: int,
    test_name: str = "Test B",
) -> Dict[str, object]:
    """
    B: same color + different hierarchies.

    For each sampled pair of prompts that share a color but differ in
    hierarchy, the pair passes when the color-slice cosine is higher than both
    the full-embedding cosine and the hierarchy-slice cosine. A companion pair
    with a different color measures how often the full embedding ranks the
    same-color pair above the different-color one.

    Args:
        model, processor: CLIP model and processor passed to get_text_embedding.
        cfg: supplies the device and the color / hierarchy slice widths.
        num_examples: number of random pairs to evaluate (must be > 0).
        num_printed: number of example rows shown in the printed table.
        test_name: label used in the printed output.

    Returns:
        Dict with "overall" (bool: every pair passed) and float rates/averages
        under "accuracy_full512", "pass_rate", "color_gt_full_rate",
        "color_gt_hier_rate", "avg_delta_color_vs_full", "avg_delta_hier_vs_full".

    Raises:
        ValueError: if num_examples is not positive (the rates would divide by zero).
    """
    if num_examples <= 0:
        raise ValueError(f"{test_name}: num_examples must be a positive integer, got {num_examples}")

    positive_pairs: List[Tuple[str, str]] = []
    negative_pairs: List[Tuple[str, str]] = []
    for _ in range(num_examples):
        color = random.choice(COLORS)
        h1, h2 = random.sample(HIERARCHIES, 2)
        negative_color = random.choice([c for c in COLORS if c != color])
        positive_pairs.append((build_text_query(color, h1), build_text_query(color, h2)))
        # The left element is unused later, but it is still generated so the
        # sequence of random draws is unchanged.
        negative_pairs.append((build_text_query(color, h1), build_text_query(negative_color, h2)))

    # The prompt space is small (colors x hierarchies x templates), so the same
    # text recurs many times; encode each distinct text only once.
    embedding_cache: Dict[str, torch.Tensor] = {}

    def embed(text: str) -> torch.Tensor:
        if text not in embedding_cache:
            embedding_cache[text] = get_text_embedding(model, processor, cfg.device, text)
        return embedding_cache[text]

    color_end = cfg.color_emb_dim
    hier_end = cfg.color_emb_dim + cfg.hierarchy_emb_dim

    rows: List[List[str]] = []
    pair_outcomes: List[bool] = []
    full512_outcomes: List[bool] = []
    color_gt_full_outcomes: List[bool] = []
    color_gt_hier_outcomes: List[bool] = []
    delta_color_vs_full_values: List[float] = []
    delta_hier_vs_full_values: List[float] = []

    for (left, right), (_, negative_right) in zip(positive_pairs, negative_pairs):
        emb_left = embed(left)
        emb_right = embed(right)
        emb_negative_right = embed(negative_right)

        sim_512 = cosine(emb_left, emb_right)
        sim_16 = cosine(emb_left[:color_end], emb_right[:color_end])
        sim_hier = cosine(emb_left[color_end:hier_end], emb_right[color_end:hier_end])
        sim_512_negative = cosine(emb_left, emb_negative_right)

        delta_color_vs_full_pct = delta_percent(sim_512, sim_16)
        delta_hier_vs_full_pct = delta_percent(sim_512, sim_hier)
        delta_color_vs_full_values.append(delta_color_vs_full_pct)
        delta_hier_vs_full_values.append(delta_hier_vs_full_pct)

        first16_higher_than_full = sim_16 > sim_512
        color_higher_than_hier = sim_16 > sim_hier
        pair_ok = first16_higher_than_full and color_higher_than_hier
        pair_outcomes.append(pair_ok)
        color_gt_full_outcomes.append(first16_higher_than_full)
        color_gt_hier_outcomes.append(color_higher_than_hier)
        full512_outcomes.append(sim_512 > sim_512_negative)

        rows.append(
            [
                f"{left} vs {right}",
                f"{sim_16:.4f}",
                f"{sim_hier:.4f}",
                f"{sim_512:.4f}",
                f"{delta_color_vs_full_pct:+.2f}%",
                f"{delta_hier_vs_full_pct:+.2f}%",
                format_bool(pair_ok),
            ]
        )

    print_table(
        f"{test_name}: Same color, different hierarchies (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
        [
            "Pair",
            "CosSim first16(color)",
            "CosSim hier64",
            "CosSim full512",
            "Delta first16 vs full512 (%)",
            "Delta hier64 vs full512 (%)",
            "Result",
        ],
        rows[:num_printed],
    )

    total = len(pair_outcomes)
    overall = all(pair_outcomes)
    pass_rate = sum(pair_outcomes) / total
    full512_accuracy = sum(full512_outcomes) / total
    color_gt_full_rate = sum(color_gt_full_outcomes) / total
    color_gt_hier_rate = sum(color_gt_hier_outcomes) / total
    avg_delta_color_vs_full = sum(delta_color_vs_full_values) / total
    avg_delta_hier_vs_full = sum(delta_hier_vs_full_values) / total
    print(f"{test_name} aggregate: {sum(pair_outcomes)}/{total} passed ({pass_rate:.2%})")
    print(f" sub-condition color > full512: {sum(color_gt_full_outcomes)}/{total} ({color_gt_full_rate:.2%})")
    print(f" sub-condition color > hier: {sum(color_gt_hier_outcomes)}/{total} ({color_gt_hier_rate:.2%})")
    print(
        f"{test_name} full512 pair-discrimination accuracy "
        f"(same-color > different-color): {sum(full512_outcomes)}/{total} "
        f"({full512_accuracy:.2%})"
    )
    print(
        f"{test_name} avg deltas: "
        f"first16 vs full512 = {avg_delta_color_vs_full:+.2f}%, "
        f"hier64 vs full512 = {avg_delta_hier_vs_full:+.2f}%"
    )
    return {
        "overall": overall,
        "accuracy_full512": full512_accuracy,
        "pass_rate": pass_rate,
        "color_gt_full_rate": color_gt_full_rate,
        "color_gt_hier_rate": color_gt_hier_rate,
        "avg_delta_color_vs_full": avg_delta_color_vs_full,
        "avg_delta_hier_vs_full": avg_delta_hier_vs_full,
    }
|
|
|
|
|
|
def run_test_c(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    cfg: RuntimeConfig,
    num_examples: int,
    num_printed: int,
    test_name: str = "Test C",
) -> Dict[str, object]:
    """
    C: Subspace Decomposition Consistency.

    Encode a full description (e.g. "red dress in cotton"), a standalone color
    ("red"), and a standalone hierarchy ("dress"). Then verify:
    - The color slice of the full embedding aligns with the color slice of the
      color-only embedding more than with that of the hierarchy-only embedding.
    - The hierarchy slice of the full embedding aligns with the hierarchy slice
      of the hierarchy-only embedding more than with that of the color-only one.

    Args:
        model, processor: CLIP model and processor passed to get_text_embedding.
        cfg: supplies the device and the color / hierarchy slice widths.
        num_examples: number of random descriptions to evaluate (must be > 0).
        num_printed: number of example rows shown in the printed table.
        test_name: label used in the printed output.

    Returns:
        Dict with "overall" (bool: every example passed) and float values under
        "pass_rate", "color_match_rate", "hier_match_rate", "avg_color_match",
        "avg_color_cross", "avg_hier_match", "avg_hier_cross".

    Raises:
        ValueError: if num_examples is not positive (the rates would divide by zero).
    """
    if num_examples <= 0:
        raise ValueError(f"{test_name}: num_examples must be a positive integer, got {num_examples}")

    # Only a handful of distinct colors, hierarchies and full prompts exist, so
    # encode each distinct text once instead of on every iteration.
    embedding_cache: Dict[str, torch.Tensor] = {}

    def embed(text: str) -> torch.Tensor:
        if text not in embedding_cache:
            embedding_cache[text] = get_text_embedding(model, processor, cfg.device, text)
        return embedding_cache[text]

    color_end = cfg.color_emb_dim
    hier_end = cfg.color_emb_dim + cfg.hierarchy_emb_dim

    rows: List[List[str]] = []
    color_match_outcomes: List[bool] = []
    hier_match_outcomes: List[bool] = []
    pair_outcomes: List[bool] = []
    sim_color_match_values: List[float] = []
    sim_color_cross_values: List[float] = []
    sim_hier_match_values: List[float] = []
    sim_hier_cross_values: List[float] = []

    for _ in range(num_examples):
        color = random.choice(COLORS)
        hierarchy = random.choice(HIERARCHIES)
        full_text = build_text_query(color, hierarchy)

        emb_full = embed(full_text)
        emb_color = embed(color)
        emb_hier = embed(hierarchy)

        # Color slice of each of the three embeddings.
        full_color = emb_full[:color_end]
        color_color = emb_color[:color_end]
        hier_color = emb_hier[:color_end]

        # Hierarchy slice of each of the three embeddings.
        full_hier = emb_full[color_end:hier_end]
        color_hier = emb_color[color_end:hier_end]
        hier_hier = emb_hier[color_end:hier_end]

        # "match": slice compared with the standalone text of the same attribute.
        sim_color_match = cosine(full_color, color_color)
        sim_hier_match = cosine(full_hier, hier_hier)

        # "cross": slice compared with the standalone text of the other attribute.
        sim_color_cross = cosine(full_color, hier_color)
        sim_hier_cross = cosine(full_hier, color_hier)

        sim_color_match_values.append(sim_color_match)
        sim_color_cross_values.append(sim_color_cross)
        sim_hier_match_values.append(sim_hier_match)
        sim_hier_cross_values.append(sim_hier_cross)

        color_ok = sim_color_match > sim_color_cross
        hier_ok = sim_hier_match > sim_hier_cross
        pair_ok = color_ok and hier_ok
        color_match_outcomes.append(color_ok)
        hier_match_outcomes.append(hier_ok)
        pair_outcomes.append(pair_ok)

        rows.append([
            full_text,
            color,
            hierarchy,
            f"{sim_color_match:.4f}",
            f"{sim_color_cross:.4f}",
            f"{sim_hier_match:.4f}",
            f"{sim_hier_cross:.4f}",
            format_bool(pair_ok),
        ])

    print_table(
        f"{test_name}: Subspace Decomposition Consistency "
        f"(showing {min(num_printed, len(rows))}/{len(rows)} examples)",
        [
            "Full description",
            "Color",
            "Hierarchy",
            "ColorSub match",
            "ColorSub cross",
            "HierSub match",
            "HierSub cross",
            "Result",
        ],
        rows[:num_printed],
    )

    total = len(pair_outcomes)
    pass_rate = sum(pair_outcomes) / total
    color_rate = sum(color_match_outcomes) / total
    hier_rate = sum(hier_match_outcomes) / total
    avg_color_match = sum(sim_color_match_values) / total
    avg_color_cross = sum(sim_color_cross_values) / total
    avg_hier_match = sum(sim_hier_match_values) / total
    avg_hier_cross = sum(sim_hier_cross_values) / total

    print(f"{test_name} aggregate: {sum(pair_outcomes)}/{total} passed ({pass_rate:.2%})")
    print(f" sub-condition color_match > color_cross: {sum(color_match_outcomes)}/{total} ({color_rate:.2%})")
    print(f" sub-condition hier_match > hier_cross: {sum(hier_match_outcomes)}/{total} ({hier_rate:.2%})")
    print(
        f"{test_name} avg similarities: "
        f"color_match={avg_color_match:.4f}, color_cross={avg_color_cross:.4f}, "
        f"hier_match={avg_hier_match:.4f}, hier_cross={avg_hier_cross:.4f}"
    )

    return {
        "overall": all(pair_outcomes),
        "pass_rate": pass_rate,
        "color_match_rate": color_rate,
        "hier_match_rate": hier_rate,
        "avg_color_match": avg_color_match,
        "avg_color_cross": avg_color_cross,
        "avg_hier_match": avg_hier_match,
        "avg_hier_cross": avg_hier_cross,
    }
|
|
|
|
# Fashion-MNIST class index -> label string. Several indices share a label
# (5, 7 and 9 all map to "shoes").
# NOTE(review): the values look like the project's hierarchy vocabulary;
# the visible zero_shot_fashion_mnist() uses dataset.classes instead of this
# mapping -- confirm where it is consumed.
FASHION_MNIST_LABELS = {
    0: "top",
    1: "pant",
    2: "sweater",
    3: "dress",
    4: "coat",
    5: "shoes",
    6: "shirt",
    7: "shoes",
    8: "accessories",
    9: "shoes",
}






# Fashion-MNIST class index -> class name, one distinct name per index.
FASHION_MNIST_ORIGINAL_LABELS = {
    0: "T-shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot",
}


# Relative dataset paths (resolved against the current working directory).
FASHION_MNIST_CSV = "data/fashion-mnist_test.csv"
INTERNAL_DATASET_CSV = "data/data.csv"
|
|
|
|
def fashion_mnist_pixels_to_tensor(pixel_values: np.ndarray, image_size: int = 224) -> torch.Tensor:
    """Turn 784 flat Fashion-MNIST pixel values into a normalized 3-channel image tensor.

    The values are reshaped to 28x28, cast to uint8, replicated across three
    channels, resized to ``image_size`` x ``image_size``, converted with
    ``ToTensor`` and normalized with the mean/std constants below.
    """
    gray = pixel_values.reshape(28, 28).astype(np.uint8)
    rgb = np.stack((gray, gray, gray), axis=-1)
    pil_image = Image.fromarray(rgb)

    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(pil_image)
|
|
|
|
def get_image_embedding(
    model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image_tensor: torch.Tensor
) -> torch.Tensor:
    """Return the L2-normalized image embedding for a single image tensor.

    The tensor is converted to a PIL image and sent through ``encode_image``,
    so the CLIP processor applies its own preprocessing; the batch axis of the
    result is squeezed away.

    NOTE(review): if the caller passes a tensor that was already normalized
    with mean/std (as fashion_mnist_pixels_to_tensor produces), its values fall
    outside [0, 1] and the PIL conversion will presumably distort the image
    before the processor normalizes it a second time -- confirm the intended
    input range.
    """
    # Local import kept from the original block; the file's top-level imports
    # do not provide to_pil_image.
    from torchvision.transforms.functional import to_pil_image

    # PIL conversion needs a CPU tensor. The previous round trip through
    # `device` (unsqueeze -> to(device) -> squeeze -> cpu) left shape and
    # values unchanged and only cost two transfers, so it is dropped.
    pil_img = to_pil_image(image_tensor.cpu())
    return F.normalize(encode_image(model, processor, pil_img, device), dim=-1).squeeze(0)
|
|
|
|
def get_image_embedding_from_pil(
    model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image: Image.Image
) -> torch.Tensor:
    """Encode one PIL image and return its L2-normalized embedding with the batch axis removed."""
    raw_features = encode_image(model, processor, image, device)
    unit_features = F.normalize(raw_features, dim=-1)
    return unit_features.squeeze(0)
|
|
|
|
def get_text_embeddings_batch(
    model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, texts: List[str]
) -> torch.Tensor:
    """Encode a list of texts and return one L2-normalized embedding per row."""
    raw_features = encode_text(model, processor, texts, device)
    return F.normalize(raw_features, dim=-1)
|
|
|
|
def get_prompt_ensembled_text_embeddings(
    model: CLIPModelTransformers,
    processor: CLIPProcessor,
    device: torch.device,
    labels: List[str],
    templates: List[str],
) -> torch.Tensor:
    """Return one prompt-ensembled embedding per label.

    Every template is formatted with each label (via its ``{label}``
    placeholder) and encoded as a normalized batch; the per-template batches
    are averaged and the mean is re-normalized.
    """
    per_template = [
        get_text_embeddings_batch(
            model,
            processor,
            device,
            [template.format(label=label) for label in labels],
        )
        for template in templates
    ]
    mean_embedding = torch.stack(per_template, dim=0).mean(dim=0)
    return F.normalize(mean_embedding, dim=-1)
|
|
|
|
def get_internal_label_prior(labels: List[str]) -> torch.Tensor:
    """
    Return a probability vector over *labels* from the internal dataset's
    hierarchy frequencies (each count gets +1e-3 smoothing before normalizing).

    A uniform vector is returned when the internal CSV is missing, cannot be
    read, or has no usable rows.
    """
    uniform = torch.ones(len(labels), dtype=torch.float32) / max(len(labels), 1)

    if not Path(INTERNAL_DATASET_CSV).exists():
        return uniform
    try:
        frame = pd.read_csv(INTERNAL_DATASET_CSV, usecols=["hierarchy"]).dropna()
    except Exception:
        return uniform
    if len(frame) == 0:
        return uniform

    normalized = [normalize_hierarchy_label(value) for value in frame["hierarchy"].astype(str)]
    counts = pd.Series(normalized).value_counts().to_dict()

    smooth = 1e-3
    smoothed = [float(counts.get(label, 0.0)) + smooth for label in labels]
    probs = torch.tensor(smoothed, dtype=torch.float32)
    return probs / probs.sum()
|
|
|
|
def get_adaptive_label_prior(labels: List[str]) -> Tuple[torch.Tensor, float]:
    """
    Return ``(prior, recommended_weight)`` for the candidate *labels*.

    Labels present in the internal dataset get their relative frequency;
    unknown labels get ``1 / len(labels)``; the vector is then renormalized.
    The weight is ``0.15 * overlap ** 2`` where ``overlap`` is the fraction of
    candidate labels found in the dataset, so it shrinks toward zero as more
    labels are out-of-domain. A uniform prior with weight 0.0 is returned when
    the CSV is missing, unreadable, or empty.
    """
    n_labels = max(len(labels), 1)
    uniform = torch.ones(len(labels), dtype=torch.float32) / n_labels

    if not Path(INTERNAL_DATASET_CSV).exists():
        return uniform, 0.0
    try:
        frame = pd.read_csv(INTERNAL_DATASET_CSV, usecols=["hierarchy"]).dropna()
    except Exception:
        return uniform, 0.0
    if len(frame) == 0:
        return uniform, 0.0

    normalized = [normalize_hierarchy_label(value) for value in frame["hierarchy"].astype(str)]
    counts = pd.Series(normalized).value_counts().to_dict()
    total_count = sum(counts.values())
    default_prob = 1.0 / n_labels

    weights = []
    known_hits = 0
    for label in labels:
        if label in counts:
            known_hits += 1
            weights.append(counts.get(label, 0.0) / total_count)
        else:
            weights.append(default_prob)

    probs = torch.tensor(weights, dtype=torch.float32)
    probs = probs / probs.sum()

    overlap = known_hits / n_labels
    recommended_weight = 0.15 * (overlap ** 2)
    return probs, recommended_weight
|
|
|
|
def zero_shot_fashion_mnist(
    model,
    processor,
    device,
    batch_size: int = 64,
    data_root: str = "./data") -> float:
    """Zero-shot accuracy over the whole Fashion-MNIST test split.

    Each image is matched against "a photo of a <class>" prompts built from
    ``dataset.classes``; the prompt with the highest dot product between
    normalized embeddings is the prediction. The dataset is downloaded into
    *data_root* when absent. Prints and returns the accuracy.
    """
    def collate_pil(samples):
        # Keep the PIL images in a plain list; only the targets become a tensor.
        pictures = [sample[0] for sample in samples]
        targets = torch.tensor([sample[1] for sample in samples])
        return pictures, targets

    to_three_channels = transforms.Grayscale(num_output_channels=3)
    test_set = datasets.FashionMNIST(
        root=data_root, train=False, download=True, transform=to_three_channels,
    )
    loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_pil)

    class_prompts = [f"a photo of a {label}" for label in test_set.classes]
    text_embs = F.normalize(
        encode_text(model, processor, class_prompts, device).to(device).float(), dim=-1
    )

    hits = 0
    seen = 0
    for pil_images, labels in tqdm(loader, desc="Zero-shot Fashion-MNIST"):
        img_embs = F.normalize(
            encode_image(model, processor, pil_images, device).to(device).float(), dim=-1
        )
        preds = (img_embs @ text_embs.T).argmax(dim=-1).cpu()
        hits += (preds == labels).sum().item()
        seen += labels.size(0)

    accuracy = hits / seen
    print(f"Zero-shot accuracy on Fashion MNIST: {accuracy:.4f} ({hits}/{seen})")
    return accuracy
|
|
|
|
|
|
def zero_shot_kagl(
    model,
    processor,
    device,
    batch_size: int = 64,
    num_examples: int = 10000,
) -> Optional[Dict[str, float]]:
    """Zero-shot accuracy and weighted F1 on the Marqo/KAGL dataset (``category2`` labels).

    Up to *num_examples* items are taken after a seeded shuffle. Each image is
    matched against "a photo of a <label>" prompts built from the labels seen
    in the sample; the prompt with the highest dot product between normalized
    embeddings is the prediction.

    Returns:
        ``{"accuracy": ..., "weighted_f1": ...}``, or ``None`` (after printing
        the reason) when the ``datasets`` package or the dataset itself cannot
        be loaded, or when no usable sample is found.

    Raises:
        ValueError: if *batch_size* is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    try:
        from datasets import load_dataset
    except Exception:
        print("Skipping zero_shot_kagl: datasets package not available")
        return None

    try:
        dataset = load_dataset("Marqo/KAGL", split="data")
    except Exception as exc:
        print(f"Skipping zero_shot_kagl: failed to load dataset ({exc})")
        return None

    dataset = dataset.shuffle(seed=42).select(range(min(num_examples, len(dataset))))

    pil_images: List[Image.Image] = []
    labels_text: List[str] = []
    for item in dataset:
        raw_label = item.get("category2")
        image_obj = item.get("image")
        if raw_label is None or image_obj is None:
            continue

        # Images arrive either as PIL-like objects or as dicts holding raw bytes.
        if hasattr(image_obj, "convert"):
            image = image_obj.convert("RGB")
        elif isinstance(image_obj, dict) and "bytes" in image_obj:
            image = Image.open(BytesIO(image_obj["bytes"])).convert("RGB")
        else:
            continue
        pil_images.append(image)
        labels_text.append(str(raw_label).strip())

    if not pil_images:
        print("Skipping zero_shot_kagl: no valid samples")
        return None

    candidate_labels = sorted(set(labels_text))
    label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
    all_labels = np.array([label_to_idx[label] for label in labels_text], dtype=np.int64)

    prompts = [f"a photo of a {label}" for label in candidate_labels]
    text_embs = encode_text(model, processor, prompts, device).to(device).float()
    text_embs = F.normalize(text_embs, dim=-1)

    all_preds: List[np.ndarray] = []
    for start in tqdm(range(0, len(pil_images), batch_size), desc="Zero-shot KAGL"):
        batch_images = pil_images[start : start + batch_size]
        img_embs = encode_image(model, processor, batch_images, device).to(device).float()
        img_embs = F.normalize(img_embs, dim=-1)
        sim = img_embs @ text_embs.T
        all_preds.append(sim.argmax(dim=-1).cpu().numpy())

    # pil_images is non-empty and batch_size is positive, so at least one batch
    # was scored; the former empty-input fallbacks were unreachable.
    pred_array = np.concatenate(all_preds, axis=0)
    accuracy = float((pred_array == all_labels).mean())
    weighted_f1 = float(f1_score(all_labels, pred_array, average="weighted"))
    print(f"KAGL accuracy: {accuracy:.4f}")
    # The metric is the support-weighted average, not a macro average, so the
    # log label says "weighted F1" rather than "weighted macro F1".
    print(f"KAGL weighted F1: {weighted_f1:.4f}")
    return {"accuracy": accuracy, "weighted_f1": weighted_f1}
|
|
|
|
def zero_shot_internal(
    model,
    processor,
    device,
    batch_size: int = 64,
    num_examples: int = 10000,
    csv_path: str = INTERNAL_DATASET_CSV,
) -> Optional[Dict[str, float]]:
    """Zero-shot accuracy and weighted F1 on the internal dataset.

    The CSV at ``csv_path`` must contain a ``hierarchy`` column plus either
    ``local_image_path`` (preferred when present) or ``image_url``. Rows are
    shuffled with a fixed seed and images are collected until
    ``num_examples`` have been loaded. Each image is then assigned the
    candidate label whose prompt embedding has the highest dot product with
    the L2-normalised image embedding.

    Args:
        model: Model handed to ``encode_text`` / ``encode_image``.
        processor: Processor handed to the same helpers.
        device: Device the embeddings are moved to before scoring.
        batch_size: Number of images encoded per forward pass.
        num_examples: Maximum number of images to evaluate.
        csv_path: Location of the dataset CSV.

    Returns:
        ``{"accuracy": ..., "weighted_f1": ...}``, or ``None`` when the CSV
        is missing, lacks the required columns, or yields no loadable image.
    """
    csv_file = Path(csv_path)
    if not csv_file.exists():
        print(f"Skipping zero_shot_internal: {csv_path} not found")
        return None

    df = pd.read_csv(csv_file)
    use_local = "local_image_path" in df.columns
    required_cols = {"hierarchy", "local_image_path"} if use_local else {"hierarchy", "image_url"}
    if not required_cols.issubset(df.columns):
        print(f"Skipping zero_shot_internal: missing required columns {required_cols}")
        return None

    img_col = "local_image_path" if use_local else "image_url"
    # Fixed random_state keeps the evaluated subset identical across runs.
    df = df.dropna(subset=["hierarchy", img_col]).sample(frac=1.0, random_state=42)

    pil_images: List[Image.Image] = []
    labels_text: List[str] = []
    # Only two columns are needed, so iterate them directly instead of
    # materialising a Series per row with iterrows().
    for raw_label, img_ref in zip(df["hierarchy"], df[img_col]):
        if len(pil_images) >= num_examples:
            break
        try:
            if use_local:
                img_path = Path(str(img_ref))
                if not img_path.exists():
                    continue
                # The context manager closes the file deterministically;
                # convert() returns an image that no longer needs it.
                with Image.open(img_path) as opened:
                    image = opened.convert("RGB")
            else:
                response = requests.get(str(img_ref), timeout=5)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception:
            # Deliberate best-effort: unreadable or unreachable images are
            # skipped so one bad row does not abort the evaluation.
            continue
        label = normalize_hierarchy_label(str(raw_label))
        pil_images.append(image)
        labels_text.append(label)

    if not pil_images:
        print("Skipping zero_shot_internal: no valid samples")
        return None

    # The candidate set is exactly the labels present in the sample.
    candidate_labels = sorted(set(labels_text))
    label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
    all_labels = np.array([label_to_idx[label] for label in labels_text], dtype=np.int64)

    prompts = [f"a photo of a {label}" for label in candidate_labels]
    text_embs = encode_text(model, processor, prompts, device).to(device).float()
    text_embs = F.normalize(text_embs, dim=-1)

    all_preds: List[np.ndarray] = []
    for start in tqdm(range(0, len(pil_images), batch_size), desc="Zero-shot Internal"):
        batch_images = pil_images[start : start + batch_size]
        img_embs = encode_image(model, processor, batch_images, device).to(device).float()
        img_embs = F.normalize(img_embs, dim=-1)
        # Both sides are unit-normalised, so the product is cosine similarity.
        sim = img_embs @ text_embs.T
        preds = sim.argmax(dim=-1).cpu().numpy()
        all_preds.append(preds)

    pred_array = np.concatenate(all_preds, axis=0) if all_preds else np.array([], dtype=np.int64)
    accuracy = float((pred_array == all_labels).mean()) if len(all_labels) else 0.0
    weighted_f1 = f1_score(all_labels, pred_array, average="weighted") if len(all_labels) else 0.0
    print(f"Internal accuracy: {accuracy:.4f}")
    print(f"Internal weighted macro F1: {weighted_f1:.4f}")
    return {"accuracy": accuracy, "weighted_f1": float(weighted_f1)}
|
|
|
|
# Exact-match table: lower-cased dataset category string -> hierarchy label.
# Kept at module level so it is built once rather than on every call.
_HIERARCHY_LABEL_SYNONYMS: Dict[str, str] = {
    "t-shirt/top": "top",
    "top": "top",
    "tee": "top",
    "t-shirt": "top",
    "shirt": "shirt",
    "shirts": "shirt",
    "pullover": "sweater",
    "sweater": "sweater",
    "coat": "coat",
    "jacket": "jacket",
    "outerwear": "coat",
    "trouser": "pant",
    "trousers": "pant",
    "pants": "pant",
    "pant": "pant",
    "jeans": "pant",
    "dress": "dress",
    "skirt": "skirt",
    "shorts": "short",
    "short": "short",
    "sandal": "shoes",
    "sneaker": "shoes",
    "ankle boot": "shoes",
    "shoe": "shoes",
    "shoes": "shoes",
    "flip flops": "shoes",
    "footwear": "shoes",
    "shoe accessories": "shoes",
    "bag": "accessories",
    "bags": "accessories",
    "accessory": "accessories",
    "accessories": "accessories",
    "belts": "accessories",
    "eyewear": "accessories",
    "jewellery": "accessories",
    "jewelry": "accessories",
    "headwear": "accessories",
    "wallets": "accessories",
    "watches": "accessories",
    "mufflers": "accessories",
    "scarves": "accessories",
    "stoles": "accessories",
    "ties": "accessories",
    "topwear": "top",
    "bottomwear": "pant",
    "innerwear": "underwear",
    "loungewear and nightwear": "underwear",
    "saree": "dress",
    "boots": "shoes",
    "outer": "coat",
    "sunglasses": "accessories",
    "scarf & tie": "accessories",
    "scarf/tie": "accessories",
    "belt": "accessories",
}

# Last-resort substring rules, tried in order; the first keyword contained in
# the label decides the category.
_HIERARCHY_KEYWORD_FALLBACKS: Tuple[Tuple[str, str], ...] = (
    ("capri", "pant"),
    ("denim", "pant"),
    ("skinny", "pant"),
    ("boyfriend", "pant"),
    ("graphic", "top"),
    ("longsleeve", "top"),
    ("leather", "jacket"),
)


def normalize_hierarchy_label(raw_label: str) -> str:
    """Map a dataset category string to an internal hierarchy label.

    Resolution order:
      1. Exact match of the stripped, lower-cased label in the synonym table.
      2. Whatever ``_HIERARCHY_EXTRACTOR.extract_hierarchy`` returns, if truthy.
      3. The first keyword rule whose keyword is a substring of the label.
      4. The stripped, lower-cased label itself.

    Args:
        raw_label: Category string as found in the dataset; it is passed
            through ``str()`` first, so non-string values are tolerated.

    Returns:
        The normalized hierarchy label.
    """
    label = str(raw_label).strip().lower()

    exact = _HIERARCHY_LABEL_SYNONYMS.get(label)
    if exact is not None:
        return exact

    result = _HIERARCHY_EXTRACTOR.extract_hierarchy(label)
    if result:
        return result

    for keyword, category in _HIERARCHY_KEYWORD_FALLBACKS:
        if keyword in label:
            return category

    return label
|
|
|
|
|
|
| |
# Integer ModaNet category ids mapped to category names.
# NOTE(review): load_modanet_samples builds its own id -> name map from the
# annotation file and does not reference this table.
MODANET_CATEGORIES = {
    1: "bag",
    2: "belt",
    3: "boots",
    4: "footwear",
    5: "outer",
    6: "dress",
    7: "sunglasses",
    8: "pants",
    9: "top",
    10: "shorts",
    11: "skirt",
    12: "headwear",
    13: "scarf/tie",
}

# Locations read by load_modanet_samples: the annotation JSON and the
# directory that holds the image files.
MODANET_ANNOTATIONS_JSON = "data/modanet_instances_train.json"
MODANET_IMAGES_DIR = "data/modanet_images/images"
|
|
|
|
def load_modanet_samples(
    num_examples: int,
) -> Tuple[List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]]]:
    """Return ``(baseline_samples, gap_samples, color_samples)`` from ModaNet.

    Reads the annotation JSON at ``MODANET_ANNOTATIONS_JSON`` (using its
    ``categories``, ``images`` and ``annotations`` entries) and the image
    files under ``MODANET_IMAGES_DIR``. An image may carry several
    annotations; the one with the largest ``area`` field is used as its
    label (a missing ``area`` counts as 0, and ties keep the first seen).

    Args:
        num_examples: Maximum number of images to load.

    Returns:
        A 3-tuple of lists of ``(image, label)`` pairs:
          * ``baseline_samples`` - labels are the category names from the
            annotation file (``"unknown"`` for an unmapped category id);
          * ``gap_samples`` - the same images with labels passed through
            ``normalize_hierarchy_label``;
          * ``color_samples`` - always an empty list in this implementation.
        All three lists are empty when the annotations or the image
        directory are missing.
    """
    import json as _json

    ann_path = Path(MODANET_ANNOTATIONS_JSON)
    img_dir = Path(MODANET_IMAGES_DIR)

    if not ann_path.exists():
        print(f" Skipping ModaNet: annotations not found at {MODANET_ANNOTATIONS_JSON}")
        return [], [], []
    if not img_dir.exists():
        print(f" Skipping ModaNet: images directory not found at {MODANET_IMAGES_DIR}")
        return [], [], []

    print(" Loading ModaNet annotations...")
    # Binary mode lets json detect the encoding itself instead of relying on
    # the platform's default text encoding.
    with open(ann_path, "rb") as f:
        coco = _json.load(f)

    cat_map = {c["id"]: c["name"] for c in coco["categories"]}
    img_map = {img["id"]: img["file_name"] for img in coco["images"]}

    # image_id -> (category_id, area) of the largest annotation seen so far.
    best_per_image: Dict[int, Tuple[int, float]] = {}
    for ann in coco["annotations"]:
        img_id = ann["image_id"]
        cat_id = ann["category_id"]
        area = ann.get("area", 0)
        if img_id not in best_per_image or area > best_per_image[img_id][1]:
            best_per_image[img_id] = (cat_id, area)

    # Seeded shuffle so the selected subset is reproducible.
    image_ids = list(best_per_image.keys())
    rng = random.Random(42)
    rng.shuffle(image_ids)

    baseline_samples: List[Tuple[Image.Image, str]] = []
    gap_samples: List[Tuple[Image.Image, str]] = []

    for img_id in image_ids:
        if len(baseline_samples) >= num_examples:
            break
        file_name = img_map.get(img_id)
        if file_name is None:
            continue
        img_path = img_dir / file_name
        if not img_path.exists():
            continue
        try:
            # Close the file deterministically; convert() returns an image
            # that no longer depends on the open file.
            with Image.open(img_path) as opened:
                image = opened.convert("RGB")
        except Exception:
            # Deliberate best-effort: skip images that cannot be decoded.
            continue

        cat_id, _ = best_per_image[img_id]
        native_label = cat_map.get(cat_id, "unknown")
        gap_label = normalize_hierarchy_label(native_label)
        baseline_samples.append((image, native_label))
        gap_samples.append((image, gap_label))

    print(f" ModaNet: loaded {len(baseline_samples)} valid samples (from {len(best_per_image)} annotated images)")
    return baseline_samples, gap_samples, []
|
|
|
|
def zero_shot_modanet(
    model,
    processor,
    device,
    batch_size: int = 64,
    num_examples: int = 10000,
    use_gap_labels: bool = True,
) -> Optional[Dict[str, float]]:
    """Evaluate zero-shot image-to-text classification on ModaNet.

    Samples come from ``load_modanet_samples``. With ``use_gap_labels`` the
    normalized label set is used, otherwise the dataset's own category names.
    Every image is assigned the candidate prompt with the highest dot product
    between L2-normalised embeddings.

    Returns:
        ``{"accuracy": ..., "weighted_f1": ...}``, or ``None`` when no
        samples could be loaded.
    """
    native_pairs, gap_pairs, _ = load_modanet_samples(num_examples)
    if use_gap_labels:
        pairs = gap_pairs
    else:
        pairs = native_pairs
    if not pairs:
        print("Skipping zero_shot_modanet: no valid samples")
        return None

    images = [pair[0] for pair in pairs]
    label_names = [pair[1] for pair in pairs]

    # Candidate classes are the distinct labels present in the sample.
    class_names = sorted(set(label_names))
    class_index = {name: position for position, name in enumerate(class_names)}
    targets = np.array([class_index[name] for name in label_names], dtype=np.int64)

    text_prompts = [f"a photo of a {name}" for name in class_names]
    text_matrix = encode_text(model, processor, text_prompts, device).to(device).float()
    text_matrix = F.normalize(text_matrix, dim=-1)

    batch_predictions: List[np.ndarray] = []
    for offset in tqdm(range(0, len(images), batch_size), desc="Zero-shot ModaNet"):
        chunk = images[offset : offset + batch_size]
        image_matrix = encode_image(model, processor, chunk, device).to(device).float()
        image_matrix = F.normalize(image_matrix, dim=-1)
        # Unit-norm rows on both sides: the product is cosine similarity.
        scores = image_matrix @ text_matrix.T
        batch_predictions.append(scores.argmax(dim=-1).cpu().numpy())

    if batch_predictions:
        predictions = np.concatenate(batch_predictions, axis=0)
    else:
        predictions = np.array([], dtype=np.int64)

    if len(targets):
        accuracy = float((predictions == targets).mean())
        weighted_f1 = f1_score(targets, predictions, average="weighted")
    else:
        accuracy = 0.0
        weighted_f1 = 0.0

    label_kind = "GAP" if use_gap_labels else "native"
    print(f"ModaNet ({label_kind}) accuracy: {accuracy:.4f}")
    print(f"ModaNet ({label_kind}) weighted macro F1: {weighted_f1:.4f}")
    return {"accuracy": accuracy, "weighted_f1": float(weighted_f1)}
|
|
|
|
def _run_structure_test(
    run_fn,
    letter: str,
    model,
    processor,
    baseline_model,
    baseline_processor,
    cfg,
) -> Tuple[Optional[Dict[str, object]], Optional[Dict[str, object]]]:
    """Run one structure test on GAP-CLIP, then on the baseline if available."""
    result = run_fn(
        model,
        processor,
        cfg,
        num_examples=DEFAULT_NUM_EXAMPLES,
        num_printed=DEFAULT_NUM_PRINTED,
    )
    baseline_result = None
    if baseline_model is not None and baseline_processor is not None:
        baseline_result = run_fn(
            baseline_model,
            baseline_processor,
            cfg,
            num_examples=DEFAULT_NUM_EXAMPLES,
            num_printed=DEFAULT_NUM_PRINTED,
            test_name=f"Baseline Test {letter}",
        )
    return result, baseline_result


def _require_pass_rate(
    result: Optional[Dict[str, object]],
    failure_prefix: str,
    threshold: float = 0.95,
) -> None:
    """Raise AssertionError when a test ran and its pass rate is below threshold."""
    if result is None:
        return
    pass_rate = float(result["pass_rate"])
    # "not >=" rather than "<" so a NaN pass rate is treated as a failure.
    if not pass_rate >= threshold:
        raise AssertionError(f"{failure_prefix} {pass_rate:.2%} < {threshold:.0%}.")


def main(
    selected_tests: set[str],
    model=None,
    processor=None,
    baseline_model=None,
    baseline_processor=None,
) -> None:
    """Run the selected embedding-structure tests and print a summary.

    Args:
        selected_tests: Test letters to run; "A", "B", "C" and "D" are
            recognised.
        model: Pre-loaded GAP-CLIP model; loaded from the configured
            checkpoint when this or ``processor`` is ``None``.
        processor: Processor matching ``model``.
        baseline_model: Pre-loaded baseline model; ``patrickjohncyh/fashion-clip``
            is loaded when this or ``baseline_processor`` is ``None`` and at
            least one of tests A-D is selected.
        baseline_processor: Processor matching ``baseline_model``.

    Raises:
        FileNotFoundError: The main checkpoint has to be loaded but is missing.
        AssertionError: Test A, B or C ran with a pass rate below 95%. This is
            raised explicitly so the gate still applies under ``python -O``.
    """
    random.seed(42)
    cfg = resolve_runtime_config()

    if model is None or processor is None:
        model_path = Path(cfg.main_model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"Main model checkpoint not found: {cfg.main_model_path}")
        print("Loading model...")
        print(f" device: {cfg.device}")
        print(f" checkpoint: {cfg.main_model_path}")
        print(f" dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total={cfg.main_emb_dim}")
        model, processor = load_main_model(cfg.device, cfg.main_model_path)
        print("Model loaded.")
    else:
        print(f"Using pre-loaded GAP-CLIP model (dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total={cfg.main_emb_dim})")

    result_a: Optional[Dict[str, object]] = None
    result_b: Optional[Dict[str, object]] = None
    result_c: Optional[Dict[str, object]] = None
    baseline_result_a: Optional[Dict[str, object]] = None
    baseline_result_b: Optional[Dict[str, object]] = None
    baseline_result_c: Optional[Dict[str, object]] = None

    if baseline_model is None or baseline_processor is None:
        if any(t in selected_tests for t in ("A", "B", "C", "D")):
            print("\nLoading baseline model (patrickjohncyh/fashion-clip)...")
            baseline_name = "patrickjohncyh/fashion-clip"
            baseline_processor = CLIPProcessor.from_pretrained(baseline_name)
            baseline_model = CLIPModelTransformers.from_pretrained(baseline_name).to(cfg.device)
            baseline_model.eval()
            print("Baseline model loaded.")

    # Tests A-C share one shape: GAP-CLIP first, then the baseline.
    if "A" in selected_tests:
        result_a, baseline_result_a = _run_structure_test(
            run_test_a, "A", model, processor, baseline_model, baseline_processor, cfg
        )
    if "B" in selected_tests:
        result_b, baseline_result_b = _run_structure_test(
            run_test_b, "B", model, processor, baseline_model, baseline_processor, cfg
        )
    if "C" in selected_tests:
        result_c, baseline_result_c = _run_structure_test(
            run_test_c, "C", model, processor, baseline_model, baseline_processor, cfg
        )

    if "D" in selected_tests:
        # Invariant: the loading step above always provides a baseline when D
        # is selected.
        assert baseline_model is not None and baseline_processor is not None

        print("\n" + "=" * 120)
        print("Test D — Notebook-style zero-shot accuracy")
        print("=" * 120)
        # Evaluation runs in dict-literal order: GAP-CLIP then baseline for
        # each dataset.
        d_results: Dict[str, Dict[str, Optional[Dict[str, float]]]] = {
            "Fashion-MNIST": {
                "gap": {"accuracy": zero_shot_fashion_mnist(model=model, processor=processor, device=cfg.device, batch_size=64)},
                "base": {"accuracy": zero_shot_fashion_mnist(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64)},
            },
            "KAGL Marqo": {
                "gap": zero_shot_kagl(model=model, processor=processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
                "base": zero_shot_kagl(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
            },
            "Internal dataset": {
                "gap": zero_shot_internal(model=model, processor=processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
                "base": zero_shot_internal(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
            },
            "ModaNet": {
                "gap": zero_shot_modanet(model=model, processor=processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES, use_gap_labels=True),
                "base": zero_shot_modanet(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES, use_gap_labels=True),
            },
        }

        print("\n" + "-" * 120)
        print("Test D summary")
        print("-" * 120)
        summary_rows: List[List[str]] = []
        for ds in ["Fashion-MNIST", "KAGL Marqo", "ModaNet", "Internal dataset"]:
            gap_result = d_results[ds]["gap"]
            base_result = d_results[ds]["base"]
            # A skipped dataset yields None and is rendered as "N/A".
            gap_acc = None if gap_result is None else gap_result.get("accuracy")
            base_acc = None if base_result is None else base_result.get("accuracy")
            summary_rows.append([
                ds,
                f"{gap_acc:.2%}" if gap_acc is not None else "N/A",
                f"{base_acc:.2%}" if base_acc is not None else "N/A",
            ])
        print_table(
            "Test D — zero-shot accuracy (notebook protocol)",
            ["Dataset", "GAP-CLIP", "Fashion-CLIP (baseline)"],
            summary_rows,
        )

    print("\n" + "=" * 120)
    print("Final Summary")
    print("=" * 120)
    print(f"Tests selected: {''.join(sorted(selected_tests))}")
    if result_a is not None:
        print(f"Test A overall: {format_bool(bool(result_a['overall']))}")
        print(f"Test A full512 accuracy: {float(result_a['accuracy_full512']):.2%}")
    if baseline_result_a is not None:
        print(f"Baseline Test A full512 accuracy: {float(baseline_result_a['accuracy_full512']):.2%}")
    if result_b is not None:
        print(f"Test B overall: {format_bool(bool(result_b['overall']))}")
        print(f"Test B full512 accuracy: {float(result_b['accuracy_full512']):.2%}")
    if baseline_result_b is not None:
        print(f"Baseline Test B full512 accuracy: {float(baseline_result_b['accuracy_full512']):.2%}")
    if result_c is not None:
        print(f"Test C overall: {format_bool(bool(result_c['overall']))}")
        print(f" pass rate: {float(result_c['pass_rate']):.2%}")
        print(f" avg color_match={float(result_c['avg_color_match']):.4f} vs cross={float(result_c['avg_color_cross']):.4f}")
        print(f" avg hier_match={float(result_c['avg_hier_match']):.4f} vs cross={float(result_c['avg_hier_cross']):.4f}")
    if baseline_result_c is not None:
        print(f"Baseline Test C overall: {format_bool(bool(baseline_result_c['overall']))}")
        print(f" baseline pass rate: {float(baseline_result_c['pass_rate']):.2%}")

    # Only GAP-CLIP results gate the run; baseline results are informational.
    _require_pass_rate(result_a, "Test A failed: pass rate")
    _require_pass_rate(result_b, "Test B failed: pass rate")
    _require_pass_rate(result_c, "Test C failed: subspace decomposition pass rate")

    print("\nAll embedding-structure tests passed.")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: choose the tests and optionally override the
    # per-test sample count.
    parser = argparse.ArgumentParser(description="Embedding structure evaluation")
    parser.add_argument("--tests", default="ABCD", help="Which tests to run, e.g. 'C' or 'ABCD'")
    parser.add_argument("--num-examples", type=int, default=None, help="Override DEFAULT_NUM_EXAMPLES")
    args = parser.parse_args()
    if args.num_examples is not None:
        # Rebinding the module global takes effect because main() reads
        # DEFAULT_NUM_EXAMPLES when it is called.
        DEFAULT_NUM_EXAMPLES = args.num_examples
    selected_tests = set(args.tests.upper())
    # Reject empty or unknown selections: main() ignores letters it does not
    # recognise and would otherwise report success without running anything.
    unknown_tests = selected_tests - set("ABCD")
    if unknown_tests or not selected_tests:
        parser.error(
            f"--tests must be a non-empty combination of A, B, C, D (got {args.tests!r})"
        )
    main(selected_tests)
|
|