AbstractPhil committed
Commit 5624568 · verified · 1 Parent(s): b155f4d

Create probe.py

Files changed (1)
  1. code/probe.py +524 -0
code/probe.py ADDED
@@ -0,0 +1,524 @@
+ """
+ Comprehensive testing cell for BaselineViT (RoseFace-aware).
+ Run AFTER loading your model & checkpoint in Colab.
+ Assumes `model` and `get_cifar100_dataloaders` are already defined.
+ """
+
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from sklearn.manifold import TSNE
+ from sklearn.decomposition import PCA
+ from sklearn.cluster import KMeans
+ from sklearn.metrics import silhouette_score
+ from pathlib import Path
+ import json
+ from tqdm import tqdm
+
+ # =========================
+ # RoseFace-aware utilities
+ # =========================
+
+ @torch.no_grad()
+ def principal_angle_overlap(class_pentachora):
+     """
+     Measure subspace overlap between classes (lower = better decoupling).
+     Returns (mean_fro, std_fro) of ||U_a^T U_b||_F across all class pairs.
+     """
+     device = class_pentachora[0].vertices.device
+     dtype = class_pentachora[0].vertices.dtype
+     C = len(class_pentachora)
+     U = []
+     for p in class_pentachora:
+         V = p.vertices.to(device=device, dtype=dtype)  # [5, D]
+         c = V.mean(dim=0, keepdim=True)
+         A = V - c  # centered vertices [5, D]
+         # reduced QR on A^T (D x 5) -> orthonormal basis of the class subspace in R^D
+         Q, _ = torch.linalg.qr(A.t(), mode='reduced')  # [D, r]
+         U.append(Q)
+     overlaps = []
+     for a in range(C):
+         for b in range(a + 1, C):
+             M = U[a].t() @ U[b]  # [r_a, r_b]
+             overlaps.append(torch.linalg.norm(M, 'fro').item())
+     if not overlaps:
+         return 0.0, 0.0
+     return float(np.mean(overlaps)), float(np.std(overlaps))
+
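+ # Interpretation: M = U_a^T U_b has the cosines of the principal angles as its
+ # singular values, so ||M||_F^2 = sum_i cos^2(theta_i). Reduced QR keeps 5
+ # columns per class, so identical pentachora score sqrt(5) and near-orthogonal
+ # subspaces score ~0. Quick sanity check with a hypothetical stand-in object
+ # (only a `.vertices` attribute is required):
+ #
+ #   from types import SimpleNamespace
+ #   p = SimpleNamespace(vertices=torch.randn(5, 64))
+ #   print(principal_angle_overlap([p, p]))  # -> (~2.236, 0.0), i.e. sqrt(5)
+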
+ @torch.no_grad()
+ def face_usage_heatmap(model, features_proj, targets, norm_type='l1'):
+     """
+     Compute the per-class face (triad) usage heatmap [C, 10].
+     features_proj: [N, D] L1-normalized (from model forward outputs).
+     """
+     device, dtype = features_proj.device, features_proj.dtype
+     C = model.num_classes
+     # the 10 triangular faces of a pentachoron (all vertex triplets)
+     triplets = torch.tensor([
+         [0, 1, 2], [0, 1, 3], [0, 1, 4],
+         [0, 2, 3], [0, 2, 4], [0, 3, 4],
+         [1, 2, 3], [1, 2, 4], [1, 3, 4],
+         [2, 3, 4],
+     ], device=device, dtype=torch.long)
+     counts = torch.zeros(C, 10, device=device, dtype=torch.long)
+
+     for cls in torch.unique(targets):
+         idx = (targets == cls)
+         if idx.sum() == 0:
+             continue
+         f = features_proj[idx]  # [b, D]
+         p = model.class_pentachora[int(cls)]
+         Vn = p.vertices_norm if norm_type == 'l1' else F.normalize(p.vertices, dim=-1)  # [5, D]
+
+         # build the 10 face barycenters, normalized the same way as the features
+         faces = []
+         for t in triplets:
+             b = (Vn[t[0]] + Vn[t[1]] + Vn[t[2]]) / 3.0
+             if norm_type == 'l1':
+                 b = b / (b.abs().sum() + 1e-8)
+             else:
+                 b = F.normalize(b.unsqueeze(0), dim=-1).squeeze(0)
+             faces.append(b)
+         F10 = torch.stack(faces, dim=0)  # [10, D]
+         sims = f @ F10.t()               # [b, 10]
+         winner = sims.argmax(dim=1)      # [b]
+         binc = torch.bincount(winner, minlength=10)  # [10]
+         counts[int(cls)] += binc
+
+     counts = counts.float()
+     counts = counts / (counts.sum(dim=1, keepdim=True) + 1e-9)
+     return counts  # [C, 10], rows sum to 1
+
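+ # Usage sketch, assuming `feats` as returned by FeatureAnalyzer.extract_all_features:
+ #
+ #   heat = face_usage_heatmap(model, feats['features_proj'].to(device),
+ #                             feats['labels'].to(device), norm_type='l1')
+ #   plt.imshow(heat.cpu().numpy(), aspect='auto', cmap='viridis')
+ #   plt.xlabel('face (triad)'); plt.ylabel('class'); plt.colorbar(); plt.show()
+ #
+ # Each row is the fraction of that class's samples won by each of the 10 faces.
+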
+ @torch.no_grad()
+ def margin_stats(cos_pre, targets):
+     """
+     Compute the margin Δ = pos - best_neg from PRE-margin cosines.
+     """
+     pos = cos_pre.gather(1, targets.view(-1, 1)).squeeze(1)  # [B]
+     masked = cos_pre.masked_fill(F.one_hot(targets, cos_pre.size(1)).bool(), -1e9)
+     neg = masked.max(dim=1).values  # [B] best competing class
+     delta = pos - neg
+     return float(delta.mean()), float(delta.std())
+
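+ # Worked example: two samples, two classes.
+ #
+ #   cos_pre = torch.tensor([[0.9, 0.2], [0.1, 0.8]])
+ #   targets = torch.tensor([0, 1])
+ #   margin_stats(cos_pre, targets)  # -> (0.7, 0.0): deltas are 0.9-0.2 and 0.8-0.1
+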
+ # ============================================
+ # FEATURE EXTRACTION AND ANALYSIS (upgraded)
+ # ============================================
+
+ class FeatureAnalyzer:
+     """
+     Analyze feature capacity and geometric patterns.
+     Now aware of RoseFace:
+       - can run with or without margin at inference (margin_mode)
+       - stores pre-margin and post-margin cosines
+     """
+     def __init__(self, model, dataloader, device=None, margin_mode='none'):
+         """
+         margin_mode:
+           'none'  -> don't pass targets to model (no margin at eval)
+           'apply' -> pass targets (apply margin at eval)
+           'both'  -> run both passes; the returned dict adds *_margin variants
+         """
+         self.model = model
+         self.dataloader = dataloader
+         self.device = device or next(model.parameters()).device
+         self.model.eval()
+         assert margin_mode in ('none', 'apply', 'both')
+         self.margin_mode = margin_mode
+
+     def _forward_once(self, images, labels, apply_margin):
+         # forward pass; return a dict of tensors moved to CPU
+         if apply_margin:
+             outputs = self.model(images, return_features=True, targets=labels)
+         else:
+             outputs = self.model(images, return_features=True)
+
+         out = {k: v.detach().cpu() for k, v in outputs.items() if isinstance(v, torch.Tensor)}
+         # derive post-margin cosines (RoseFace only): cos_post = logits / s
+         if getattr(self.model, 'head_type', 'legacy') == 'roseface':
+             s = float(getattr(self.model, 'scale_s', 1.0))
+             if s > 0 and 'logits' in out:
+                 out['cos_post'] = out['logits'] / s
+         return out
+
+     def extract_all_features(self, max_batches=None):
+         """
+         Extract features, pre-margin cosines, and post-margin cosines (if available).
+         Returns a dict with keys:
+           - cls_features
+           - features_proj
+           - similarities (pre-margin cos)
+           - cos_post (post-margin cos; RoseFace only)
+           - logits
+           - labels
+         If margin_mode == 'both', *_margin variants are included as well.
+         """
+         agg = {}
+
+         def append_batch(prefix, out_tensors, labels):
+             # accumulate per-batch tensors under prefixed keys
+             for k, v in out_tensors.items():
+                 agg.setdefault(f'{prefix}{k}', []).append(v)
+             agg.setdefault(f'{prefix}labels', []).append(labels.cpu())
+
+         with torch.no_grad():
+             for i, (images, labels) in enumerate(tqdm(self.dataloader, desc="Extracting features")):
+                 if max_batches is not None and i >= max_batches:
+                     break
+                 images = images.to(self.device, non_blocking=True)
+                 labels = labels.to(self.device, non_blocking=True)
+
+                 if self.margin_mode in ('none', 'both'):
+                     out0 = self._forward_once(images, labels, apply_margin=False)
+                     append_batch('', out0, labels)
+
+                 if self.margin_mode in ('apply', 'both'):
+                     out1 = self._forward_once(images, labels, apply_margin=True)
+                     append_batch('m_', out1, labels)
+
+         # concatenate everything we collected
+         for k, lst in agg.items():
+             agg[k] = torch.cat(lst, dim=0)
+
+         # helper: prefer the no-margin key, fall back to the 'm_' variant
+         def pick(key: str):
+             return agg.get(key, agg.get(f"m_{key}", torch.empty(0)))
+
+         # unify the view for downstream code
+         if self.margin_mode == 'both':
+             features = {
+                 'cls_features': pick('features'),
+                 'features_proj': pick('features_proj'),
+                 'similarities': agg.get('similarities', torch.empty(0)),
+                 'cos_post': agg.get('cos_post', torch.empty(0)),
+                 'labels': agg.get('labels', torch.empty(0)),
+                 'similarities_margin': agg.get('m_similarities', torch.empty(0)),
+                 'cos_post_margin': agg.get('m_cos_post', torch.empty(0)),
+                 'logits': agg.get('logits', torch.empty(0)),
+                 'logits_margin': agg.get('m_logits', torch.empty(0)),
+             }
+         else:
+             # works for both margin_mode='none' and margin_mode='apply'
+             features = {
+                 'cls_features': pick('features'),
+                 'features_proj': pick('features_proj'),
+                 'similarities': pick('similarities'),  # pre-margin cosines
+                 'cos_post': pick('cos_post'),          # post-margin cosines (RoseFace)
+                 'labels': pick('labels'),
+                 'logits': pick('logits'),
+             }
+         return features
+
+     def analyze_feature_collapse(self, features):
+         print("\n=== FEATURE COLLAPSE ANALYSIS ===")
+         cls_features = features['cls_features'].numpy()
+         unique_patterns = self._count_unique_patterns(cls_features)
+         print(f"Estimated unique patterns: {unique_patterns}/100 classes")
+
+         feature_std = cls_features.std(axis=0).mean()
+         print(f"Average feature std: {feature_std:.4f}")
+
+         labels = features['labels'].numpy()
+         sample_size = min(1000, len(labels))
+         indices = np.random.choice(len(labels), sample_size, replace=False)
+         silhouette = silhouette_score(cls_features[indices], labels[indices])
+         print(f"Silhouette score: {silhouette:.3f}")
+
+         # count classes whose centroids sit unusually close together
+         class_features = {}
+         for i in range(100):
+             mask = labels == i
+             if mask.sum() > 0:
+                 class_features[i] = cls_features[mask].mean(axis=0)
+         if class_features:
+             centroids = np.stack(list(class_features.values()))
+             d = np.linalg.norm(centroids[:, None] - centroids[None, :], axis=2)
+             thr = np.percentile(d[d > 0], 20)
+             close_pairs = (d < thr) & (d > 0)
+             classes_with_close_neighbors = close_pairs.sum(axis=1)
+             print(f"Classes with very similar features: {(classes_with_close_neighbors > 2).sum()}/100")
+
+         return {'unique_patterns': unique_patterns, 'feature_std': feature_std, 'silhouette': silhouette}
+
+     def analyze_geometric_patterns(self, features):
+         print("\n=== GEOMETRIC PATTERN ANALYSIS ===")
+         sims = features['similarities']  # pre-margin cosines [N, C]
+         print(f"Average max cosine: {sims.max(dim=1)[0].mean():.3f}")
+         print(f"Average min cosine: {sims.min(dim=1)[0].mean():.3f}")
+         print(f"Cosine std: {sims.std():.3f}")
+
+         high_sim_threshold = sims.mean() + sims.std()
+         high_sim_count = (sims > high_sim_threshold).sum(dim=1).float().mean()
+         print(f"Avg classes above (mean+std): {high_sim_count:.1f}/100")
+
+         labels = features['labels']
+         correct = sims.gather(1, labels.view(-1, 1)).squeeze(1).mean().item()
+         wrong = (sims.sum(dim=1) - sims.gather(1, labels.view(-1, 1)).squeeze(1)) / (sims.size(1) - 1)
+         margin = correct - wrong.mean().item()
+         print(f"Avg cosine margin (correct - mean wrong): {margin:.3f}")
+
+         # RoseFace-specific: if post-margin cosines are present, measure the
+         # shift the margin induces on the target column
+         if 'cos_post' in features and features['cos_post'].numel() > 0:
+             cos_post = features['cos_post']
+             pos_pre = sims.gather(1, labels.view(-1, 1))
+             pos_post = cos_post.gather(1, labels.view(-1, 1))
+             shift = (pos_post - pos_pre).mean().item()
+             print(f"Avg target cosine shift (post - pre): {shift:.3f}")
+
+         return {
+             'max_cos': sims.max(dim=1)[0].mean().item(),
+             'cos_std': sims.std().item(),
+             'high_sim_classes': high_sim_count.item(),
+             'discrimination_margin': margin,
+         }
+
+     def test_linear_probe(self, features, num_epochs=50):
+         print("\n=== LINEAR PROBE TEST ===")
+         X = features['cls_features']
+         y = features['labels']
+         # simple 80/20 split; assumes the dataloader order is not class-sorted
+         n_train = int(0.8 * len(y))
+         X_train, y_train = X[:n_train], y[:n_train]
+         X_test, y_test = X[n_train:], y[n_train:]
+
+         probe = torch.nn.Linear(X_train.shape[1], 100).to(self.device)
+         opt = torch.optim.Adam(probe.parameters(), lr=0.01)
+
+         X_train = X_train.to(self.device); y_train = y_train.to(self.device)
+         X_test = X_test.to(self.device); y_test = y_test.to(self.device)
+
+         best = 0.0
+         for epoch in range(num_epochs):
+             logits = probe(X_train)
+             loss = F.cross_entropy(logits, y_train)
+             opt.zero_grad(); loss.backward(); opt.step()
+             if epoch % 10 == 0:
+                 with torch.no_grad():
+                     acc = (probe(X_test).argmax(dim=1) == y_test).float().mean().item()
+                 best = max(best, acc)
+                 print(f"  Epoch {epoch}: Test acc = {acc*100:.1f}%")
+         with torch.no_grad():
+             final = (probe(X_test).argmax(dim=1) == y_test).float().mean().item()
+         best = max(best, final)
+         print(f"Best linear probe accuracy: {best*100:.1f}%")
+         return best
+
+     def visualize_features(self, features, method='tsne', n_samples=2000):
+         print(f"\n=== FEATURE VISUALIZATION ({method.upper()}) ===")
+         cls_features = features['cls_features'].numpy()
+         labels = features['labels'].numpy()
+         n_samples = min(n_samples, len(labels))
+         idx = np.random.choice(len(labels), n_samples, replace=False)
+         X = cls_features[idx]; y = labels[idx]
+
+         print(f"Reducing {n_samples} samples to 2D...")
+         reducer = TSNE(n_components=2, random_state=42, perplexity=30) if method == 'tsne' else PCA(n_components=2)
+         X2 = reducer.fit_transform(X)
+
+         plt.figure(figsize=(12, 9))
+         scatter = plt.scatter(X2[:, 0], X2[:, 1], c=y, cmap='nipy_spectral', alpha=0.6, s=15)
+         plt.title(f'Feature Space Visualization ({method.upper()})'); plt.xlabel('Comp 1'); plt.ylabel('Comp 2')
+
+         print("Estimating visual clusters...")
+         silhouette_scores, K = [], list(range(30, 60, 5))
+         for k in K:
+             kmeans = KMeans(n_clusters=k, random_state=42, n_init=3)
+             cls_lbl = kmeans.fit_predict(X2)
+             silhouette_scores.append(silhouette_score(X2, cls_lbl))
+         best_k = K[int(np.argmax(silhouette_scores))]
+         kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=5)
+         cluster_labels = kmeans.fit_predict(X2)
+         n_populated = len(np.unique(cluster_labels))
+         plt.text(0.02, 0.98, f'Estimated clusters: {n_populated}', transform=plt.gca().transAxes,
+                  va='top', fontsize=12, bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
+         cbar = plt.colorbar(scatter, ticks=np.arange(0, 100, 10)); cbar.set_label('Class', rotation=270, labelpad=15)
+         plt.tight_layout(); plt.show()
+         return X2, n_populated
+
+     def analyze_pentachora_usage(self):
+         print("\n=== PENTACHORA USAGE ANALYSIS ===")
+         print(f"Classes: {self.model.num_classes}")
+         print(f"Embed dim: {self.model.embed_dim} | Penta dim: {self.model.pentachora_dim}")
+         print(f"Head: {getattr(self.model, 'head_type', 'legacy')} | Prototype: {getattr(self.model, 'prototype_mode', 'n/a')} | Margin: {getattr(self.model, 'margin_type', 'n/a')}")
+         if hasattr(self.model, 'to_pentachora_dim'):
+             if isinstance(self.model.to_pentachora_dim, torch.nn.Linear):
+                 print(f"Projection: Linear {self.model.embed_dim}→{self.model.pentachora_dim}")
+             else:
+                 print("Projection: Identity")
+
+         # inter-class centroid similarity (legacy view)
+         centroids = self.model.get_class_centroids()
+         sim = centroids @ centroids.t()
+         mask = ~torch.eye(self.model.num_classes, dtype=torch.bool, device=sim.device)
+         off = sim[mask]
+         print(f"\nCentroid sims: mean={off.mean():.3f} max={off.max():.3f} min={off.min():.3f}")
+
+         # principal-angle overlap between class subspaces
+         mean_fro, std_fro = principal_angle_overlap(self.model.class_pentachora)
+         print(f"Principal-angle Fro overlap: mean={mean_fro:.3f} ± {std_fro:.3f} (lower is better)")
+
+         return {'mean_similarity': off.mean().item(), 'max_similarity': off.max().item(), 'mean_fro_overlap': mean_fro}
+
+     def run_full_analysis(self):
+         print("="*60); print("COMPREHENSIVE FEATURE ANALYSIS"); print("="*60)
+         print(f"\nExtracting features (margin_mode={self.margin_mode}) ...")
+         feats = self.extract_all_features(max_batches=50)
+         print(f"✓ Extracted features from {len(feats['labels'])} samples")
+
+         res = {}
+         res['collapse'] = self.analyze_feature_collapse(feats)
+         res['geometric'] = self.analyze_geometric_patterns(feats)
+
+         # margin stats from PRE-margin cosines
+         mu, sig = margin_stats(feats['similarities'], feats['labels'])
+         print(f"PRE-margin Δ (pos - best neg): mean={mu:.3f}, std={sig:.3f}")
+
+         # face-usage heatmap
+         if 'features_proj' in feats and feats['features_proj'].numel() > 0:
+             heat = face_usage_heatmap(self.model, feats['features_proj'].to(self.device),
+                                       feats['labels'].to(self.device),
+                                       norm_type=getattr(self.model, 'norm_type', 'l1'))
+             # rows are normalized to sum to 1, so rank classes by how concentrated
+             # their usage is on a single face rather than by total row mass
+             print("Face-usage heatmap computed [C,10] (3 most face-concentrated classes):")
+             class_conc = heat.max(dim=1).values
+             top3 = torch.topk(class_conc, k=min(3, heat.size(0))).indices.tolist()
+             for c in top3:
+                 print(f"  class {c}: {heat[c].cpu().numpy().round(3)}")
+
+         res['linear_probe'] = self.test_linear_probe(feats)
+         res['pentachora'] = self.analyze_pentachora_usage()
+
+         # visualizations
+         _, n_tsne = self.visualize_features(feats, 'tsne')
+         _, n_pca = self.visualize_features(feats, 'pca')
+         res['visual_clusters'] = {'tsne': n_tsne, 'pca': n_pca}
+
+         # summary
+         print("\n" + "="*60); print("DIAGNOSIS SUMMARY"); print("="*60)
+         up = res['collapse']['unique_patterns']; lp = res['linear_probe']
+         if up <= 45 and lp <= 0.42:
+             print(f"🔴 Compact regime: {up} unique patterns; linear probe {lp*100:.1f}%")
+         elif up > 60:
+             print(f"✅ Diverse regime: {up} unique patterns; linear probe {lp*100:.1f}%")
+         else:
+             print(f"⚡ Partial bottleneck: {up} unique patterns; linear probe {lp*100:.1f}%")
+
+         return res
+
+     # ------------------------------
+     # Helpers (unchanged interface)
+     # ------------------------------
+     def _count_unique_patterns(self, features, method='elbow'):
+         X = features[:min(3000, len(features))]
+         if method == 'elbow':
+             inertias, K = [], list(range(20, 80, 5))
+             for k in K:
+                 km = KMeans(n_clusters=k, random_state=42, n_init=3)
+                 km.fit(X); inertias.append(km.inertia_)
+             diffs = np.diff(inertias); diffs2 = np.diff(diffs)
+             if len(diffs2) > 0:
+                 elbow_idx = int(np.argmax(np.abs(diffs2))) + 1
+                 est = K[elbow_idx]
+             else:
+                 est = 41  # fallback when the sweep is too short to locate an elbow
+         else:
+             scores, K = [], list(range(30, 60, 5))
+             for k in K:
+                 km = KMeans(n_clusters=k, random_state=42, n_init=3)
+                 lbl = km.fit_predict(X)
+                 scores.append(silhouette_score(X, lbl))
+             est = K[int(np.argmax(scores))]
+         return est
+
+ # ============================================
+ # QUICK TEST (RoseFace-aware)
+ # ============================================
+
+ def quick_41_percent_test(model, test_loader, device=None, apply_margin_eval=False):
+     """
+     If apply_margin_eval=True, pass targets to the model at eval (margin applied).
+     Otherwise, evaluate without margin (classic).
+     """
+     print("="*60); print("41% ACCURACY CAP HYPOTHESIS TEST"); print("="*60)
+     model.eval()
+     device = device or next(model.parameters()).device
+
+     # 1) accuracy
+     print("\n1. Verifying model accuracy...")
+     correct, total = 0, 0
+     with torch.no_grad():
+         for images, labels in tqdm(test_loader, desc="Testing"):
+             images = images.to(device)
+             labels = labels.to(device)
+             outputs = model(images, targets=labels) if apply_margin_eval else model(images)
+             pred = outputs['logits'].argmax(dim=1)
+             correct += (pred == labels).sum().item()
+             total += labels.size(0)
+     acc = 100 * correct / total
+     policy = "WITH margin" if apply_margin_eval else "NO margin"
+     print(f"  Test Accuracy ({policy}): {acc:.2f}%")
+
+     is_at_cap = abs(acc - 41.0) < 3.0
+
+     # 2) focused analysis (small sample)
+     print("\n2. Analyzing feature patterns...")
+     margin_mode = 'apply' if apply_margin_eval else 'none'
+     analyzer = FeatureAnalyzer(model, test_loader, device=device, margin_mode=margin_mode)
+     feats = analyzer.extract_all_features(max_batches=20)
+     acc_rose5 = offline_head_eval_rose5(model, feats['features_proj'].to(device), feats['labels'])
+     print(f"Offline prototype eval (rose5, no margin): {acc_rose5*100:.2f}%")
+     collapse = analyzer.analyze_feature_collapse(feats)
+     pent = analyzer.analyze_pentachora_usage()
+
+     print("\n" + "="*60); print("VERDICT"); print("="*60)
+     if is_at_cap and collapse['unique_patterns'] <= 45:
+         print("🔴 41% CAP CONFIRMED")
+         print(f"  Acc: {acc:.1f}% | Unique patterns: {collapse['unique_patterns']}")
+         print("  Likely geometric bottleneck.")
+     elif collapse['unique_patterns'] <= 45:
+         print("⚠️ FEATURE BOTTLENECK DETECTED")
+         print(f"  {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")
+     else:
+         print("✅ NO 41% BOTTLENECK")
+         print(f"  {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")
+
+     return {
+         'accuracy': acc,
+         'unique_patterns': collapse['unique_patterns'],
+         'is_bottlenecked': collapse['unique_patterns'] <= 45,
+         'pentachora_similarity': pent['mean_similarity'],
+     }
+
+ @torch.no_grad()
+ def offline_head_eval_rose5(model, features_proj, labels):
+     """Nearest-prototype accuracy using rose5 prototypes, no margin."""
+     # dual-norm bridge: re-normalize the L1-normalized features to unit L2
+     z = features_proj
+     z_l2 = z / (z.norm(p=2, dim=-1, keepdim=True) + 1e-12)
+     # build rose5 prototypes [C, D]: mean vertex plus half the mean face center
+     V = torch.stack([p.vertices for p in model.class_pentachora], dim=0).to(z.device, z.dtype)  # [C, 5, D]
+     V = V / (V.norm(p=2, dim=-1, keepdim=True) + 1e-12)
+     W = model.rose_face_weights.to(z.device, z.dtype)  # [10, 5]
+     faces = torch.einsum('tf,cfd->ctd', W, V)          # [C, 10, D] face centers
+     proto = V.mean(dim=1) + 0.5 * faces.mean(dim=1)
+     proto = proto / (proto.norm(p=2, dim=-1, keepdim=True) + 1e-12)  # [C, D]
+     cos = z_l2 @ proto.t()  # [N, C]
+     acc = (cos.argmax(dim=1) == labels.to(z.device)).float().mean().item()
+     return acc
+
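+ # Note: `model.rose_face_weights` [10, 5] is assumed to hold one row of vertex
+ # weights per triangular face, matching the 10 triplets in face_usage_heatmap;
+ # the einsum above then yields the 10 face centers per class. A uniform
+ # stand-in for smoke-testing this function in isolation:
+ #
+ #   W = torch.zeros(10, 5)
+ #   trips = [(0,1,2),(0,1,3),(0,1,4),(0,2,3),(0,2,4),
+ #            (0,3,4),(1,2,3),(1,2,4),(1,3,4),(2,3,4)]
+ #   for t, (i, j, k) in enumerate(trips):
+ #       W[t, i] = W[t, j] = W[t, k] = 1.0 / 3.0
+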
+ # ==========================
+ # RUN ANALYSIS (example)
+ # ==========================
+
+ if __name__ == "__main__":
+     print("Starting RoseFace-aware feature analysis...")
+     print(f"Model device: {next(model.parameters()).device}")
+     # dataloaders
+     train_loader, test_loader, train_transforms = get_cifar100_dataloaders(batch_size=128)
+
+     # quick test in BOTH modes (optional): compare accuracy
+     print("\nRunning quick 41% hypothesis test (NO margin at eval)...")
+     res_nom = quick_41_percent_test(model, test_loader, apply_margin_eval=False)
+
+     print("\nRunning quick 41% hypothesis test (WITH margin at eval)...")
+     res_mar = quick_41_percent_test(model, test_loader, apply_margin_eval=True)
+
+     # full analysis with richer diagnostics (no margin at eval is typical)
+     analyzer = FeatureAnalyzer(model, test_loader, margin_mode='none')
+     full_results = analyzer.run_full_analysis()
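+
+     # Optional: compare pre- vs post-margin target cosines in one pass, a
+     # minimal sketch using margin_mode='both' (cos_post_margin is only
+     # populated for a RoseFace head):
+     #
+     #   both = FeatureAnalyzer(model, test_loader, margin_mode='both')
+     #   f = both.extract_all_features(max_batches=20)
+     #   y = f['labels'].view(-1, 1)
+     #   pre = f['similarities'].gather(1, y).mean()
+     #   post = f['cos_post_margin'].gather(1, y).mean()
+     #   print(f"target cosine: pre={pre:.3f}  post-margin={post:.3f}")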