Leacb4 committed
Commit 2dd8fc6 · verified
1 Parent(s): 942267d

Upload evaluation/sec52_category_model_eval.py with huggingface_hub

Files changed (1)
  1. evaluation/sec52_category_model_eval.py +158 -503
evaluation/sec52_category_model_eval.py CHANGED
@@ -28,18 +28,13 @@ import torch
  import pandas as pd
  import numpy as np
  import matplotlib.pyplot as plt
- import seaborn as sns
  import difflib
  from collections import defaultdict
- import hashlib
- from pathlib import Path
- import requests
 
  from sklearn.metrics.pairwise import cosine_similarity
- from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
+ from sklearn.metrics import classification_report, accuracy_score
  from sklearn.preprocessing import normalize
 
- from tqdm import tqdm
  from torch.utils.data import Dataset, DataLoader
  from torchvision import transforms
  from PIL import Image
@@ -48,178 +43,29 @@ from io import BytesIO
  import warnings
  warnings.filterwarnings('ignore')
 
- from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
-
  from config import (
+     ROOT_DIR,
      main_model_path,
+     main_emb_dim,
      hierarchy_model_path,
      color_emb_dim,
      hierarchy_emb_dim,
      local_dataset_path,
      column_local_image_path,
-     images_dir,
  )
 
+ from utils.datasets import (
+     load_fashion_mnist_dataset,
+ )
+ from utils.embeddings import extract_clip_embeddings
+ from utils.metrics import (
+     compute_similarity_metrics,
+     compute_embedding_accuracy,
+     compute_centroid_accuracy,
+     predict_labels_from_embeddings,
+     create_confusion_matrix,
+ )
+ from utils.model_loader import load_gap_clip, load_baseline_fashion_clip
- # ============================================================================
- # 1. Fashion-MNIST utilities
- # ============================================================================
-
- def get_fashion_mnist_labels():
-     return {
-         0: "T-shirt/top",
-         1: "Trouser",
-         2: "Pullover",
-         3: "Dress",
-         4: "Coat",
-         5: "Sandal",
-         6: "Shirt",
-         7: "Sneaker",
-         8: "Bag",
-         9: "Ankle boot",
-     }
-
-
- def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
-     fashion_mnist_labels = get_fashion_mnist_labels()
-     hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
-     mapping = {}
-
-     for fm_label_id, fm_label in fashion_mnist_labels.items():
-         fm_label_lower = fm_label.lower()
-         matched_hierarchy = None
-
-         if fm_label_lower in hierarchy_classes_lower:
-             matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(fm_label_lower)]
-         elif any(h in fm_label_lower or fm_label_lower in h for h in hierarchy_classes_lower):
-             for h_class in hierarchy_classes:
-                 h_lower = h_class.lower()
-                 if h_lower in fm_label_lower or fm_label_lower in h_lower:
-                     matched_hierarchy = h_class
-                     break
-         else:
-             if fm_label_lower in ['t-shirt/top', 'top']:
-                 if 'top' in hierarchy_classes_lower:
-                     matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('top')]
-             elif 'trouser' in fm_label_lower:
-                 for possible in ['bottom', 'pants', 'trousers', 'trouser', 'pant']:
-                     if possible in hierarchy_classes_lower:
-                         matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
-                         break
-             elif 'pullover' in fm_label_lower:
-                 for possible in ['sweater', 'pullover']:
-                     if possible in hierarchy_classes_lower:
-                         matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
-                         break
-             elif 'dress' in fm_label_lower:
-                 if 'dress' in hierarchy_classes_lower:
-                     matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('dress')]
-             elif 'coat' in fm_label_lower:
-                 for possible in ['jacket', 'outerwear', 'coat']:
-                     if possible in hierarchy_classes_lower:
-                         matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
-                         break
-             elif fm_label_lower in ['sandal', 'sneaker', 'ankle boot']:
-                 for possible in ['shoes', 'shoe', 'sandal', 'sneaker', 'boot']:
-                     if possible in hierarchy_classes_lower:
-                         matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
-                         break
-             elif 'bag' in fm_label_lower:
-                 if 'bag' in hierarchy_classes_lower:
-                     matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('bag')]
-
-         if matched_hierarchy is None:
-             close_matches = difflib.get_close_matches(
-                 fm_label_lower, hierarchy_classes_lower, n=1, cutoff=0.6
-             )
-             if close_matches:
-                 matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(close_matches[0])]
-
-         mapping[fm_label_id] = matched_hierarchy
-         if matched_hierarchy:
-             print(f" {fm_label} ({fm_label_id}) -> {matched_hierarchy}")
-         else:
-             print(f" {fm_label} ({fm_label_id}) -> NO MATCH (will be filtered out)")
-
-     return mapping
-
-
- def convert_fashion_mnist_to_image(pixel_values):
-     image_array = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
-     image_array = np.stack([image_array] * 3, axis=-1)
-     return Image.fromarray(image_array)
-
-
- class FashionMNISTDataset(Dataset):
-     def __init__(self, dataframe, image_size=224, label_mapping=None):
-         self.dataframe = dataframe
-         self.image_size = image_size
-         self.labels_map = get_fashion_mnist_labels()
-         self.label_mapping = label_mapping
-
-         self.transform = transforms.Compose([
-             transforms.Resize((image_size, image_size)),
-             transforms.ToTensor(),
-             transforms.Normalize(
-                 mean=[0.485, 0.456, 0.406],
-                 std=[0.229, 0.224, 0.225],
-             ),
-         ])
-
-     def __len__(self):
-         return len(self.dataframe)
-
-     def __getitem__(self, idx):
-         row = self.dataframe.iloc[idx]
-
-         pixel_cols = [f"pixel{i}" for i in range(1, 785)]
-         pixel_values = row[pixel_cols].values
-
-         image = convert_fashion_mnist_to_image(pixel_values)
-         image = self.transform(image)
-
-         label_id = int(row['label'])
-         description = self.labels_map[label_id]
-         color = "unknown"
-
-         if self.label_mapping and label_id in self.label_mapping:
-             hierarchy = self.label_mapping[label_id]
-         else:
-             hierarchy = self.labels_map[label_id]
-
-         return image, description, color, hierarchy
-
-
- def load_fashion_mnist_dataset(
-     max_samples=10000,
-     hierarchy_classes=None,
-     csv_path=None,
- ):
-     if csv_path is None:
-         csv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data", "fashion-mnist_test.csv")
-     print("Loading Fashion-MNIST test dataset...")
-     df = pd.read_csv(csv_path)
-     print(f"Fashion-MNIST dataset loaded: {len(df)} samples")
-
-     label_mapping = None
-     if hierarchy_classes is not None:
-         print("\nCreating mapping from Fashion-MNIST labels to hierarchy classes:")
-         label_mapping = create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes)
-
-         valid_label_ids = [lid for lid, h in label_mapping.items() if h is not None]
-         df_filtered = df[df['label'].isin(valid_label_ids)]
-         print(f"\nAfter filtering to mappable labels: {len(df_filtered)} samples (from {len(df)})")
-         df_sample = df_filtered.head(max_samples)
-     else:
-         df_sample = df.head(max_samples)
-
-     print(f"Using {len(df_sample)} samples for evaluation")
-     return FashionMNISTDataset(df_sample, label_mapping=label_mapping)
 
 
  # ============================================================================
@@ -256,21 +102,24 @@ class KaggleHierarchyDataset(Dataset):
          return image, description, color, hierarchy
 
 
- def load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None):
-     """Load KAGL Marqo dataset with hierarchy labels derived from articleType."""
-     from datasets import load_dataset
-
-     print("Loading KAGL Marqo dataset for hierarchy evaluation...")
-     dataset = load_dataset("Marqo/KAGL")
-     df = dataset["data"].to_pandas()
+ def load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
+     """Load KAGL Marqo dataset with hierarchy labels derived from articleType.
+
+     Args:
+         raw_df: Pre-downloaded DataFrame to skip the HuggingFace download.
+     """
+     if raw_df is not None:
+         df = raw_df.copy()
+         print(f"Using cached KAGL DataFrame for hierarchy evaluation: {len(df)} samples")
+     else:
+         from datasets import load_dataset
+         print("Loading KAGL Marqo dataset for hierarchy evaluation...")
+         dataset = load_dataset("Marqo/KAGL")
+         df = dataset["data"].to_pandas()
      print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}")
 
      # Use the most specific category column as hierarchy source
-     hierarchy_col = None
-     for col in ["articleType", "category3", "category2", "subCategory", "masterCategory", "category1"]:
-         if col in df.columns:
-             hierarchy_col = col
-             break
+     hierarchy_col = 'category2'
 
      if hierarchy_col is None:
          print("WARNING: No hierarchy column found in KAGL dataset")
@@ -335,29 +184,10 @@ class LocalHierarchyDataset(Dataset):
      def __getitem__(self, idx):
          row = self.dataframe.iloc[idx]
          try:
-             image_path = row.get(column_local_image_path) if hasattr(row, "get") else None
-             if isinstance(image_path, str) and image_path and os.path.exists(image_path):
-                 image = Image.open(image_path).convert("RGB")
-             else:
-                 # Fallback: download image from URL (and cache).
-                 image_url = row.get("image_url") if hasattr(row, "get") else None
-                 if isinstance(image_url, dict) and "bytes" in image_url:
-                     image = Image.open(BytesIO(image_url["bytes"])).convert("RGB")
-                 elif isinstance(image_url, str) and image_url:
-                     cache_dir = Path(images_dir)
-                     cache_dir.mkdir(parents=True, exist_ok=True)
-                     url_hash = hashlib.md5(image_url.encode("utf-8")).hexdigest()
-                     cache_path = cache_dir / f"{url_hash}.jpg"
-                     if cache_path.exists():
-                         image = Image.open(cache_path).convert("RGB")
-                     else:
-                         resp = requests.get(image_url, timeout=10)
-                         resp.raise_for_status()
-                         image = Image.open(BytesIO(resp.content)).convert("RGB")
-                         # Cache so repeated runs are faster.
-                         image.save(cache_path, "JPEG", quality=85, optimize=True)
-                 else:
-                     raise ValueError("Missing image_path and image_url")
+             img_path = row[column_local_image_path]
+             if not os.path.isabs(img_path):
+                 img_path = os.path.join(ROOT_DIR, img_path)
+             image = Image.open(img_path).convert("RGB")
          except Exception:
              image = Image.new("RGB", (224, 224), color="gray")
          image = self.transform(image)
@@ -367,18 +197,21 @@ class LocalHierarchyDataset(Dataset):
          return image, description, color, hierarchy
 
 
- def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None):
-     """Load internal validation dataset with hierarchy labels."""
-     print("Loading local validation dataset for hierarchy evaluation...")
-     df = pd.read_csv(local_dataset_path)
-     print(f"Dataset loaded: {len(df)} samples")
-
-     # Some internal CSVs only contain `image_url` (no `local_image_path`).
-     # If so, we fall back to downloading images on-demand.
-     if column_local_image_path in df.columns:
-         df = df.dropna(subset=[column_local_image_path, "hierarchy"])
+ def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
+     """Load internal validation dataset with hierarchy labels.
+
+     Args:
+         raw_df: Pre-loaded DataFrame to skip CSV read.
+     """
+     if raw_df is not None:
+         df = raw_df.copy()
+         print(f"Using cached local DataFrame for hierarchy evaluation: {len(df)} samples")
      else:
-         df = df.dropna(subset=["hierarchy"])
+         print("Loading local validation dataset for hierarchy evaluation...")
+         df = pd.read_csv(local_dataset_path)
+         print(f"Dataset loaded: {len(df)} samples")
+
+     df = df.dropna(subset=[column_local_image_path, "hierarchy"])
      df["hierarchy"] = df["hierarchy"].astype(str).str.strip()
      df = df[df["hierarchy"].str.len() > 0]
@@ -410,25 +243,32 @@ class CategoryModelEvaluator:
      baseline Fashion-CLIP on Fashion-MNIST, KAGL Marqo, and internal datasets.
      """
 
-     def __init__(self, device='mps', directory='figures/confusion_matrices/cm_hierarchy'):
-         self.device = torch.device(device)
+     def __init__(self, device='mps', directory='gap_clip_confusion_matrices',
+                  gap_clip_model=None, gap_clip_processor=None,
+                  baseline_model=None, baseline_processor=None,
+                  hierarchy_classes=None,
+                  kaggle_raw_df=None, local_raw_df=None):
+         self.device = torch.device(device) if isinstance(device, str) else device
          self.directory = directory
+         self.kaggle_raw_df = kaggle_raw_df
+         self.local_raw_df = local_raw_df
          self.color_emb_dim = color_emb_dim
          self.hierarchy_emb_dim = hierarchy_emb_dim
+         self.main_emb_dim = main_emb_dim
+         self.hierarchy_end_dim = self.color_emb_dim + self.hierarchy_emb_dim
          os.makedirs(self.directory, exist_ok=True)
 
-         # --- load GAP-CLIP ---
-         print(f"Loading GAP-CLIP model from {main_model_path}")
-         if not os.path.exists(main_model_path):
-             raise FileNotFoundError(f"GAP-CLIP model file {main_model_path} not found")
-
-         print("Loading hierarchy classes from hierarchy model...")
-         if not os.path.exists(hierarchy_model_path):
-             raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found")
-
-         hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device)
-         self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
-         print(f"Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}")
+         # --- hierarchy classes ---
+         if hierarchy_classes is not None:
+             self.hierarchy_classes = hierarchy_classes
+             print(f"Using provided hierarchy classes: {len(self.hierarchy_classes)} classes")
+         else:
+             print("Loading hierarchy classes from hierarchy model...")
+             if not os.path.exists(hierarchy_model_path):
+                 raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found")
+             hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device)
+             self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
+             print(f"Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}")
 
          self.validation_hierarchy_classes = self._load_validation_hierarchy_classes()
          if self.validation_hierarchy_classes:
@@ -438,21 +278,23 @@ class CategoryModelEvaluator:
              print("Unable to load validation hierarchy classes, falling back to hierarchy model classes.")
              self.validation_hierarchy_classes = self.hierarchy_classes
 
-         checkpoint = torch.load(main_model_path, map_location=self.device)
-         self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
-         self.model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
-         self.model.load_state_dict(checkpoint['model_state_dict'])
-         self.model.to(self.device)
-         self.model.eval()
-         print("GAP-CLIP model loaded successfully")
-
-         # --- baseline Fashion-CLIP ---
-         print("Loading baseline Fashion-CLIP model...")
-         patrick_model_name = "patrickjohncyh/fashion-clip"
-         self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name)
-         self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device)
-         self.baseline_model.eval()
-         print("Baseline Fashion-CLIP model loaded successfully")
+         # --- load GAP-CLIP (accept pre-loaded or load from scratch) ---
+         if gap_clip_model is not None and gap_clip_processor is not None:
+             self.model = gap_clip_model
+             self.processor = gap_clip_processor
+             print("Using pre-loaded GAP-CLIP model")
+         else:
+             self.model, self.processor = load_gap_clip(main_model_path, self.device)
+             print("GAP-CLIP model loaded successfully")
+
+         # --- baseline Fashion-CLIP (accept pre-loaded or load from scratch) ---
+         if baseline_model is not None and baseline_processor is not None:
+             self.baseline_model = baseline_model
+             self.baseline_processor = baseline_processor
+             print("Using pre-loaded baseline Fashion-CLIP model")
+         else:
+             self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device)
+             print("Baseline Fashion-CLIP model loaded successfully")
 
      # ------------------------------------------------------------------
      # helpers
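Note: `load_gap_clip` and `load_baseline_fashion_clip` now live in `utils/model_loader`, which is not part of this upload. A minimal sketch of what that module presumably provides, reconstructed from the loader code deleted above (the backbone name, the `model_state_dict` checkpoint key, and the `patrickjohncyh/fashion-clip` baseline all come from the removed lines; the real module may differ):

# utils/model_loader.py -- sketch only, reconstructed from code removed in this commit.
import os
import torch
from transformers import CLIPProcessor, CLIPModel

def load_gap_clip(model_path, device):
    """Load the fine-tuned GAP-CLIP checkpoint on top of the LAION ViT-B/32 backbone."""
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"GAP-CLIP model file {model_path} not found")
    checkpoint = torch.load(model_path, map_location=device)
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    model = CLIPModel.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, processor

def load_baseline_fashion_clip(device):
    """Load the off-the-shelf Fashion-CLIP baseline."""
    name = "patrickjohncyh/fashion-clip"
    processor = CLIPProcessor.from_pretrained(name)
    model = CLIPModel.from_pretrained(name).to(device)
    model.eval()
    return model, processor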
@@ -506,196 +348,23 @@ class CategoryModelEvaluator:
          )
 
      # ------------------------------------------------------------------
-     # embedding extraction GAP-CLIP
+     # embedding extraction (delegates to shared utils)
      # ------------------------------------------------------------------
      def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
          """Full 512D embeddings from GAP-CLIP (text or image)."""
-         all_embeddings, all_colors, all_hierarchies = [], [], []
-         sample_count = 0
-
-         with torch.no_grad():
-             for batch in tqdm(dataloader, desc=f"GAP-CLIP {embedding_type} embeddings"):
-                 if sample_count >= max_samples:
-                     break
-                 images, texts, colors, hierarchies = batch
-                 images = images.to(self.device).expand(-1, 3, -1, -1)
-
-                 text_inputs = self.processor(text=list(texts), padding=True, return_tensors="pt")
-                 text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
-                 outputs = self.model(**text_inputs, pixel_values=images)
-
-                 if embedding_type == 'image':
-                     emb = outputs.image_embeds
-                 else:
-                     emb = outputs.text_embeds
-
-                 all_embeddings.append(emb.cpu().numpy())
-                 all_colors.extend(colors)
-                 all_hierarchies.extend(hierarchies)
-                 sample_count += len(images)
-
-                 del images, text_inputs, outputs, emb
-                 if torch.cuda.is_available():
-                     torch.cuda.empty_cache()
-
-         return np.vstack(all_embeddings), all_colors, all_hierarchies
+         return extract_clip_embeddings(
+             self.model, self.processor, dataloader, self.device,
+             embedding_type=embedding_type, max_samples=max_samples,
+             desc=f"GAP-CLIP {embedding_type} embeddings",
+         )
 
-     # ------------------------------------------------------------------
-     # embedding extraction — baseline Fashion-CLIP
-     # ------------------------------------------------------------------
      def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
          """L2-normalised embeddings from baseline Fashion-CLIP."""
-         all_embeddings, all_colors, all_hierarchies = [], [], []
-         sample_count = 0
-
-         with torch.no_grad():
-             for batch in tqdm(dataloader, desc=f"Baseline {embedding_type} embeddings"):
-                 if sample_count >= max_samples:
-                     break
-                 images, texts, colors, hierarchies = batch
-
-                 if embedding_type == 'text':
-                     inp = self.baseline_processor(
-                         text=list(texts), return_tensors="pt",
-                         padding=True, truncation=True, max_length=77,
-                     )
-                     inp = {k: v.to(self.device) for k, v in inp.items()}
-                     feats = self.baseline_model.get_text_features(**inp)
-                     feats = feats / feats.norm(dim=-1, keepdim=True)
-                     emb = feats
-
-                 elif embedding_type == 'image':
-                     pil_images = []
-                     for i in range(images.shape[0]):
-                         t = images[i]
-                         if t.min() < 0 or t.max() > 1:
-                             mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
-                             std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
-                             t = torch.clamp(t * std + mean, 0, 1)
-                         pil_images.append(transforms.ToPILImage()(t))
-
-                     inp = self.baseline_processor(images=pil_images, return_tensors="pt")
-                     inp = {k: v.to(self.device) for k, v in inp.items()}
-                     feats = self.baseline_model.get_image_features(**inp)
-                     feats = feats / feats.norm(dim=-1, keepdim=True)
-                     emb = feats
-                 else:
-                     inp = self.baseline_processor(
-                         text=list(texts), return_tensors="pt",
-                         padding=True, truncation=True, max_length=77,
-                     )
-                     inp = {k: v.to(self.device) for k, v in inp.items()}
-                     feats = self.baseline_model.get_text_features(**inp)
-                     feats = feats / feats.norm(dim=-1, keepdim=True)
-                     emb = feats
-
-                 all_embeddings.append(emb.cpu().numpy())
-                 all_colors.extend(colors)
-                 all_hierarchies.extend(hierarchies)
-                 sample_count += len(images)
-
-                 del emb
-                 if torch.cuda.is_available():
-                     torch.cuda.empty_cache()
-
-         return np.vstack(all_embeddings), all_colors, all_hierarchies
+         return extract_clip_embeddings(
+             self.baseline_model, self.baseline_processor, dataloader, self.device,
+             embedding_type=embedding_type, max_samples=max_samples,
+             desc=f"Baseline {embedding_type} embeddings",
+         )
-
-     # ------------------------------------------------------------------
-     # metrics
-     # ------------------------------------------------------------------
-     def compute_embedding_accuracy(self, embeddings, labels, similarities=None):
-         n = len(embeddings)
-         if n == 0:
-             return 0.0
-         if similarities is None:
-             similarities = cosine_similarity(embeddings)
-
-         correct = 0
-         for i in range(n):
-             sims = similarities[i].copy()
-             sims[i] = -1.0
-             nearest_neighbor_idx = int(np.argmax(sims))
-             predicted = labels[nearest_neighbor_idx]
-             if predicted == labels[i]:
-                 correct += 1
-         return correct / n
-
-     def compute_similarity_metrics(self, embeddings, labels):
-         max_samples = min(5000, len(embeddings))
-         if len(embeddings) > max_samples:
-             indices = np.random.choice(len(embeddings), max_samples, replace=False)
-             embeddings = embeddings[indices]
-             labels = [labels[i] for i in indices]
-
-         similarities = cosine_similarity(embeddings)
-
-         label_groups = defaultdict(list)
-         for i, label in enumerate(labels):
-             label_groups[label].append(i)
-
-         intra = []
-         for _, idxs in label_groups.items():
-             if len(idxs) > 1:
-                 for i in range(len(idxs)):
-                     for j in range(i + 1, len(idxs)):
-                         intra.append(similarities[idxs[i], idxs[j]])
-
-         inter = []
-         keys = list(label_groups.keys())
-         for i in range(len(keys)):
-             for j in range(i + 1, len(keys)):
-                 for idx1 in label_groups[keys[i]]:
-                     for idx2 in label_groups[keys[j]]:
-                         inter.append(similarities[idx1, idx2])
-
-         nn_acc = self.compute_embedding_accuracy(embeddings, labels, similarities)
-
-         return {
-             'intra_class_mean': float(np.mean(intra)) if intra else 0.0,
-             'inter_class_mean': float(np.mean(inter)) if inter else 0.0,
-             'separation_score': (float(np.mean(intra) - np.mean(inter))
-                                  if intra and inter else 0.0),
-             'nn_accuracy': nn_acc,
-         }
-
-     def compute_centroid_accuracy(self, embeddings, labels):
-         if len(embeddings) == 0:
-             return 0.0
-         emb_norm = normalize(embeddings, norm='l2')
-         unique_labels = sorted(set(labels))
-         centroids = {}
-         for label in unique_labels:
-             idx = [i for i, l in enumerate(labels) if l == label]
-             centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm='l2')[0]
-
-         correct = 0
-         for i, emb in enumerate(emb_norm):
-             best_sim, pred = -1, None
-             for label, c in centroids.items():
-                 sim = cosine_similarity([emb], [c])[0][0]
-                 if sim > best_sim:
-                     best_sim, pred = sim, label
-             if pred == labels[i]:
-                 correct += 1
-         return correct / len(labels)
-
-     def predict_labels_from_embeddings(self, embeddings, labels):
-         emb_norm = normalize(embeddings, norm='l2')
-         unique_labels = sorted(set(labels))
-         centroids = {}
-         for label in unique_labels:
-             idx = [i for i, l in enumerate(labels) if l == label]
-             centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm='l2')[0]
-
-         preds = []
-         for emb in emb_norm:
-             best_sim, pred = -1, None
-             for label, c in centroids.items():
-                 sim = cosine_similarity([emb], [c])[0][0]
-                 if sim > best_sim:
-                     best_sim, pred = sim, label
-             preds.append(pred)
-         return preds
 
      def predict_labels_nearest_neighbor(self, embeddings, labels):
          """
@@ -741,23 +410,6 @@ class CategoryModelEvaluator:
      # ------------------------------------------------------------------
      # confusion matrix & classification report
      # ------------------------------------------------------------------
-     def create_confusion_matrix(self, true_labels, predicted_labels,
-                                 title="Confusion Matrix", label_type="Label"):
-         unique_labels = sorted(set(true_labels + predicted_labels))
-         cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
-         acc = accuracy_score(true_labels, predicted_labels)
-
-         plt.figure(figsize=(10, 8))
-         sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
-                     xticklabels=unique_labels, yticklabels=unique_labels)
-         plt.title(f'{title}\nAccuracy: {acc:.3f} ({acc * 100:.1f}%)')
-         plt.ylabel(f'True {label_type}')
-         plt.xlabel(f'Predicted {label_type}')
-         plt.xticks(rotation=45)
-         plt.yticks(rotation=0)
-         plt.tight_layout()
-         return plt.gcf(), acc, cm
-
      def evaluate_classification_performance(self, embeddings, labels,
                                              embedding_type="Embeddings",
                                              label_type="Label",
@@ -765,14 +417,14 @@ class CategoryModelEvaluator:
          if method == "nn":
              preds = self.predict_labels_nearest_neighbor(embeddings, labels)
          elif method == "centroid":
-             preds = self.predict_labels_from_embeddings(embeddings, labels)
+             preds = predict_labels_from_embeddings(embeddings, labels)
          else:
              raise ValueError(f"Unknown classification method: {method}")
          acc = accuracy_score(labels, preds)
          unique_labels = sorted(set(labels))
-         fig, _, cm = self.create_confusion_matrix(
+         fig, _, cm = create_confusion_matrix(
              labels, preds,
-             embedding_type,
+             f"{embedding_type} - {label_type} Classification ({method.upper()})",
              label_type,
          )
          report = classification_report(labels, preds, labels=unique_labels,
@@ -786,6 +438,15 @@ class CategoryModelEvaluator:
              'figure': fig,
          }
 
+     def save_confusion_matrix_table(self, cm, labels, output_csv_path):
+         """
+         Save confusion matrix values with per-row totals to CSV for auditing.
+         """
+         cm_df = pd.DataFrame(cm, index=labels, columns=labels)
+         cm_df["row_total"] = cm_df.sum(axis=1)
+         cm_df.loc["column_total"] = list(cm_df[labels].sum(axis=0)) + [cm_df["row_total"].sum()]
+         cm_df.to_csv(output_csv_path)
+
      # ==================================================================
      # 3. GAP-CLIP evaluation on Fashion-MNIST
      # ==================================================================
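The new save_confusion_matrix_table helper writes the raw counts next to each PNG. A toy illustration of the resulting CSV layout ("dress" and "top" are made-up labels for this example):

import numpy as np
evaluator.save_confusion_matrix_table(
    np.array([[5, 1],
              [2, 7]]),
    ["dress", "top"],
    "cm_demo.csv",
)
# cm_demo.csv:
# ,dress,top,row_total
# dress,5,1,6
# top,2,7,9
# column_total,7,8,15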
@@ -824,10 +485,10 @@ class CategoryModelEvaluator:
          text_hier_spec = text_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
          print(f" Specialized text hierarchy shape: {text_hier_spec.shape}")
 
-         text_metrics = self.compute_similarity_metrics(text_hier_spec, text_hier)
+         text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
          text_class = self.evaluate_classification_performance(
              text_hier_spec, text_hier,
-             "Fashion-MNIST, text, hierarchy confusion matrix", "Hierarchy",
+             "GAP-CLIP Text Hierarchy (64D)", "Hierarchy",
              method="nn",
          )
          text_metrics.update(text_class)
@@ -839,18 +500,18 @@ class CategoryModelEvaluator:
          print(f" Specialized image hierarchy shape: {img_hier_spec.shape}")
 
          print(" Testing specialized 64D...")
-         spec_metrics = self.compute_similarity_metrics(img_hier_spec, img_hier)
+         spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
          spec_class = self.evaluate_classification_performance(
              img_hier_spec, img_hier,
-             "Fashion-MNIST, image, hierarchy confusion matrix", "Hierarchy",
+             "GAP-CLIP Image Hierarchy (64D)", "Hierarchy",
              method="nn",
          )
 
          print(" Testing full 512D...")
-         full_metrics = self.compute_similarity_metrics(img_full, img_hier)
+         full_metrics = compute_similarity_metrics(img_full, img_hier)
          full_class = self.evaluate_classification_performance(
              img_full, img_hier,
-             "Fashion-MNIST, image, hierarchy confusion matrix", "Hierarchy",
+             "GAP-CLIP Image Hierarchy (512D full)", "Hierarchy",
              method="nn",
          )
 
@@ -889,6 +550,11 @@ class CategoryModelEvaluator:
                  os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.png"),
                  dpi=300, bbox_inches='tight',
              )
+             self.save_confusion_matrix_table(
+                 results[key]['confusion_matrix'],
+                 results[key]['labels'],
+                 os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.csv"),
+             )
              plt.close(fig)
 
          del text_full, img_full, text_hier_spec, img_hier_spec
@@ -920,10 +586,10 @@ class CategoryModelEvaluator:
          self._validate_label_distribution(text_hier, expected_counts, "baseline text")
          print(f" Baseline text shape: {text_emb.shape}")
 
-         text_metrics = self.compute_similarity_metrics(text_emb, text_hier)
+         text_metrics = compute_similarity_metrics(text_emb, text_hier)
          text_class = self.evaluate_classification_performance(
              text_emb, text_hier,
-             "Fashion-MNIST, text, hierarchy confusion matrix", "Hierarchy",
+             "Baseline Fashion-CLIP Text - Hierarchy", "Hierarchy",
              method="nn",
          )
          text_metrics.update(text_class)
@@ -939,10 +605,10 @@ class CategoryModelEvaluator:
          self._validate_label_distribution(img_hier, expected_counts, "baseline image")
          print(f" Baseline image shape: {img_emb.shape}")
 
-         img_metrics = self.compute_similarity_metrics(img_emb, img_hier)
+         img_metrics = compute_similarity_metrics(img_emb, img_hier)
          img_class = self.evaluate_classification_performance(
              img_emb, img_hier,
-             "Fashion-MNIST, image, hierarchy confusion matrix", "Hierarchy",
+             "Baseline Fashion-CLIP Image - Hierarchy", "Hierarchy",
              method="nn",
          )
          img_metrics.update(img_class)
@@ -958,6 +624,11 @@ class CategoryModelEvaluator:
                  os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.png"),
                  dpi=300, bbox_inches='tight',
              )
+             self.save_confusion_matrix_table(
+                 results[key]['hierarchy']['confusion_matrix'],
+                 results[key]['hierarchy']['labels'],
+                 os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.csv"),
+             )
              plt.close(fig)
 
          return results
@@ -980,10 +651,10 @@ class CategoryModelEvaluator:
          text_hier_spec = text_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
          print(f" Text shape: {text_full.shape}, hierarchy subspace: {text_hier_spec.shape}")
 
-         text_metrics = self.compute_similarity_metrics(text_hier_spec, text_hier)
+         text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
          text_class = self.evaluate_classification_performance(
              text_hier_spec, text_hier,
-             f"{dataset_name}, text, hierarchy confusion matrix", "Hierarchy", method="nn",
+             f"GAP-CLIP Text Hierarchy {dataset_name}", "Hierarchy", method="nn",
          )
          text_metrics.update(text_class)
          results['text_hierarchy'] = text_metrics
@@ -993,16 +664,16 @@ class CategoryModelEvaluator:
          img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
          img_hier_spec = img_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
 
-         spec_metrics = self.compute_similarity_metrics(img_hier_spec, img_hier)
+         spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
          spec_class = self.evaluate_classification_performance(
              img_hier_spec, img_hier,
-             f"{dataset_name}, image, hierarchy confusion matrix", "Hierarchy", method="nn",
+             f"GAP-CLIP Image Hierarchy (64D) – {dataset_name}", "Hierarchy", method="nn",
          )
 
-         full_metrics = self.compute_similarity_metrics(img_full, img_hier)
+         full_metrics = compute_similarity_metrics(img_full, img_hier)
          full_class = self.evaluate_classification_performance(
              img_full, img_hier,
-             f"{dataset_name}, image, hierarchy confusion matrix", "Hierarchy", method="nn",
+             f"GAP-CLIP Image Hierarchy (512D) – {dataset_name}", "Hierarchy", method="nn",
          )
 
          if full_class['accuracy'] >= spec_class['accuracy']:
@@ -1023,6 +694,10 @@ class CategoryModelEvaluator:
                  os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.png"),
                  dpi=300, bbox_inches='tight',
              )
+             self.save_confusion_matrix_table(
+                 results[key]['confusion_matrix'], results[key]['labels'],
+                 os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.csv"),
+             )
              plt.close(fig)
 
          del text_full, img_full, text_hier_spec, img_hier_spec
@@ -1044,10 +719,10 @@ class CategoryModelEvaluator:
          text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
          print(f" Baseline text shape: {text_emb.shape}")
 
-         text_metrics = self.compute_similarity_metrics(text_emb, text_hier)
+         text_metrics = compute_similarity_metrics(text_emb, text_hier)
          text_class = self.evaluate_classification_performance(
              text_emb, text_hier,
-             f"{dataset_name}, text, hierarchy confusion matrix", "Hierarchy", method="nn",
+             f"Baseline Text Hierarchy {dataset_name}", "Hierarchy", method="nn",
          )
          text_metrics.update(text_class)
          results['text'] = {'hierarchy': text_metrics}
@@ -1061,10 +736,10 @@ class CategoryModelEvaluator:
          img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
          print(f" Baseline image shape: {img_emb.shape}")
 
-         img_metrics = self.compute_similarity_metrics(img_emb, img_hier)
+         img_metrics = compute_similarity_metrics(img_emb, img_hier)
          img_class = self.evaluate_classification_performance(
              img_emb, img_hier,
-             f"{dataset_name}, image, hierarchy confusion matrix", "Hierarchy", method="nn",
+             f"Baseline Image Hierarchy {dataset_name}", "Hierarchy", method="nn",
          )
          img_metrics.update(img_class)
          results['image'] = {'hierarchy': img_metrics}
@@ -1080,6 +755,11 @@ class CategoryModelEvaluator:
                  os.path.join(self.directory, f"baseline_{prefix}_{key}_hierarchy_confusion_matrix.png"),
                  dpi=300, bbox_inches='tight',
              )
+             self.save_confusion_matrix_table(
+                 results[key]['hierarchy']['confusion_matrix'],
+                 results[key]['hierarchy']['labels'],
+                 os.path.join(self.directory, f"baseline_{prefix}_{key}_hierarchy_confusion_matrix.csv"),
+             )
              plt.close(fig)
 
          return results
@@ -1087,10 +767,8 @@ class CategoryModelEvaluator:
      # ==================================================================
      # 6. Full evaluation across all datasets
      # ==================================================================
-     def run_full_evaluation(self, max_samples=10000, local_max_samples=None, batch_size=8):
+     def run_full_evaluation(self, max_samples=10000, batch_size=8):
          """Run hierarchy evaluation on all 3 datasets for both models."""
-         if local_max_samples is None:
-             local_max_samples = max_samples
          all_results = {}
 
          # --- Fashion-MNIST ---
@@ -1109,6 +787,7 @@ class CategoryModelEvaluator:
              kaggle_dataset = load_kaggle_marqo_with_hierarchy(
                  max_samples=max_samples,
                  hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
+                 raw_df=self.kaggle_raw_df,
              )
              if kaggle_dataset is not None and len(kaggle_dataset) > 0:
                  kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
@@ -1126,16 +805,17 @@ class CategoryModelEvaluator:
          # --- Internal (local validation) ---
          try:
              local_dataset = load_local_validation_with_hierarchy(
-                 max_samples=local_max_samples,
+                 max_samples=max_samples,
                  hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
+                 raw_df=self.local_raw_df,
              )
              if local_dataset is not None and len(local_dataset) > 0:
                  local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
                  all_results['local_gap'] = self.evaluate_gap_clip_generic(
-                     local_dataloader, "Internal", local_max_samples,
+                     local_dataloader, "Internal", max_samples,
                  )
                  all_results['local_baseline'] = self.evaluate_baseline_generic(
-                     local_dataloader, "Internal", local_max_samples,
+                     local_dataloader, "Internal", max_samples,
                  )
              else:
                  print("WARNING: Local validation dataset empty after hierarchy filtering, skipping.")
@@ -1161,13 +841,13 @@ class CategoryModelEvaluator:
              if 'text_hierarchy' in res:
                  t = res['text_hierarchy']
                  i = res['image_hierarchy']
-                 print(f" Text NN Acc: {t['nn_accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
-                 print(f" Image NN Acc: {i['nn_accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")
+                 print(f" Text NN Acc: {t['accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
+                 print(f" Image NN Acc: {i['accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")
              elif 'text' in res:
                  t = res['text']['hierarchy']
                  i = res['image']['hierarchy']
-                 print(f" Text NN Acc: {t['nn_accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
-                 print(f" Image NN Acc: {i['nn_accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")
+                 print(f" Text NN Acc: {t['accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
+                 print(f" Image NN Acc: {i['accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")
 
          return all_results
@@ -1180,33 +860,8 @@ if __name__ == "__main__":
      device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
      print(f"Using device: {device}")
 
-     directory = 'figures/confusion_matrices/cm_hierarchy'
+     directory = 'gap_clip_confusion_matrices'
      max_samples = 10000
-     local_max_samples = 1000
 
      evaluator = CategoryModelEvaluator(device=device, directory=directory)
-
-     # # Full evaluation including Fashion-MNIST and KAGL Marqo (skipped — CMs already generated)
-     # evaluator.run_full_evaluation(max_samples=max_samples, local_max_samples=local_max_samples, batch_size=8)
-
-     # Evaluate only the local/internal dataset
-     local_dataset = load_local_validation_with_hierarchy(
-         max_samples=local_max_samples,
-         hierarchy_classes=evaluator.validation_hierarchy_classes or evaluator.hierarchy_classes,
-     )
-     if local_dataset is not None and len(local_dataset) > 0:
-         local_dl = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
-         results_gap = evaluator.evaluate_gap_clip_generic(local_dl, "Internal", local_max_samples)
-         results_base = evaluator.evaluate_baseline_generic(local_dl, "Internal", local_max_samples)
-
-         print(f"\n{'=' * 60}")
-         print("INTERNAL DATASET — HIERARCHY EVALUATION SUMMARY")
-         print(f"{'=' * 60}")
-         print(f"\nGAP-CLIP:")
-         print(f" Text NN Acc: {results_gap['text_hierarchy']['nn_accuracy']*100:.1f}% | Separation: {results_gap['text_hierarchy']['separation_score']:.4f}")
-         print(f" Image NN Acc: {results_gap['image_hierarchy']['nn_accuracy']*100:.1f}% | Separation: {results_gap['image_hierarchy']['separation_score']:.4f}")
-         print(f"\nBaseline:")
-         print(f" Text NN Acc: {results_base['text']['hierarchy']['nn_accuracy']*100:.1f}% | Separation: {results_base['text']['hierarchy']['separation_score']:.4f}")
-         print(f" Image NN Acc: {results_base['image']['hierarchy']['nn_accuracy']*100:.1f}% | Separation: {results_base['image']['hierarchy']['separation_score']:.4f}")
-     else:
-         print("WARNING: Local validation dataset empty after hierarchy filtering.")
+     evaluator.run_full_evaluation(max_samples=max_samples, batch_size=8)
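Taken together, the constructor changes make repeated runs cheaper: models and source DataFrames can be loaded once and shared between evaluator instances. An illustrative driver under those assumptions (not part of the commit; `local_dataset_path` comes from config, and the directory names are arbitrary):

# Illustration only: reuse pre-loaded models and cached DataFrames across runs.
import pandas as pd
import torch
from datasets import load_dataset

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
kaggle_df = load_dataset("Marqo/KAGL")["data"].to_pandas()  # download once
local_df = pd.read_csv(local_dataset_path)                  # read once

evaluator = CategoryModelEvaluator(
    device=device,
    directory='gap_clip_confusion_matrices',
    kaggle_raw_df=kaggle_df,   # skips the HuggingFace download in run_full_evaluation
    local_raw_df=local_df,     # skips the CSV re-read
)
evaluator.run_full_evaluation(max_samples=10000, batch_size=8)

# A second pass (e.g., a different output directory) can reuse the loaded models:
rerun = CategoryModelEvaluator(
    device=device, directory='cm_rerun',
    gap_clip_model=evaluator.model, gap_clip_processor=evaluator.processor,
    baseline_model=evaluator.baseline_model, baseline_processor=evaluator.baseline_processor,
    kaggle_raw_df=kaggle_df, local_raw_df=local_df,
)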
 