Create probe.py
code/probe.py (ADDED, +524 -0)
@@ -0,0 +1,524 @@
"""
Comprehensive testing cell for BaselineViT (RoseFace-aware)
Run AFTER loading your model & checkpoint in Colab.
Assumes: model, get_cifar100_dataloaders are defined.
"""

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from pathlib import Path
import json
from tqdm import tqdm

# =========================
# RoseFace-aware utilities
# =========================

@torch.no_grad()
def principal_angle_overlap(class_pentachora):
    """
    Measure subspace overlap between classes (lower = better decoupling).
    Returns (mean_fro, std_fro) across all class pairs.
    """
    device = class_pentachora[0].vertices.device
    dtype = class_pentachora[0].vertices.dtype
    C = len(class_pentachora)
    U = []
    for p in class_pentachora:
        V = p.vertices.to(device=device, dtype=dtype)  # [5,D]
        c = V.mean(dim=0, keepdim=True)
        A = V - c  # [5,D]
        # QR on D x 5 (A^T) → orthonormal basis in R^D
        Q, _ = torch.linalg.qr(A.t(), mode='reduced')  # [D, r]
        U.append(Q)
    overlaps = []
    for a in range(C):
        for b in range(a + 1, C):
            M = U[a].t() @ U[b]  # [r_a, r_b]
            overlaps.append(torch.linalg.norm(M, 'fro').item())
    if not overlaps:
        return 0.0, 0.0
    return float(np.mean(overlaps)), float(np.std(overlaps))

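# Usage sketch for principal_angle_overlap (assumes each pentachoron exposes a
# .vertices tensor of shape [5, D], as above). Note that 5 centered vertices
# span at most a rank-4 subspace, so only up to r = 4 columns of Q are
# meaningful; identical subspaces give ||Q_a^T Q_b||_F near sqrt(r), while
# near-orthogonal class subspaces give values near 0:
#
#   mean_fro, std_fro = principal_angle_overlap(model.class_pentachora)
#   print(f"Fro overlap: {mean_fro:.3f} ± {std_fro:.3f}")
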
@torch.no_grad()
def face_usage_heatmap(model, features_proj, targets, norm_type='l1'):
    """
    Compute per-class face (triad) usage heatmap [C,10].
    features_proj: [N,D] L1-normalized (from model forward outputs)
    """
    device, dtype = features_proj.device, features_proj.dtype
    C = model.num_classes
    triplets = torch.tensor([
        [0, 1, 2], [0, 1, 3], [0, 1, 4],
        [0, 2, 3], [0, 2, 4], [0, 3, 4],
        [1, 2, 3], [1, 2, 4], [1, 3, 4],
        [2, 3, 4]
    ], device=device, dtype=torch.long)
    counts = torch.zeros(C, 10, device=device, dtype=torch.long)

    for cls in torch.unique(targets):
        idx = (targets == cls)
        if idx.sum() == 0:
            continue
        f = features_proj[idx]  # [b,D]
        p = model.class_pentachora[int(cls)]
        Vn = p.vertices_norm if norm_type == 'l1' else F.normalize(p.vertices, dim=-1)  # [5,D]

        # Build 10 faces
        faces = []
        for t in triplets:
            b = (Vn[t[0]] + Vn[t[1]] + Vn[t[2]]) / 3.0
            if norm_type == 'l1':
                b = b / (b.abs().sum() + 1e-8)
            else:
                b = F.normalize(b.unsqueeze(0), dim=-1).squeeze(0)
            faces.append(b)
        F10 = torch.stack(faces, dim=0)  # [10,D]
        sims = f @ F10.t()               # [b,10]
        winner = sims.argmax(dim=1)      # [b]
        binc = torch.bincount(winner, minlength=10)  # [10]
        counts[int(cls)] += binc

    counts = counts.float()
    counts = counts / (counts.sum(dim=1, keepdim=True) + 1e-9)
    return counts  # [C,10]

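# Usage sketch (assumes feats comes from FeatureAnalyzer.extract_all_features
# below, and that model.norm_type matches the projection normalization):
#
#   heat = face_usage_heatmap(model, feats['features_proj'].to(device),
#                             feats['labels'].to(device), norm_type='l1')
#   plt.imshow(heat.cpu().numpy(), aspect='auto')
#   plt.xlabel('face (triad)'); plt.ylabel('class'); plt.colorbar(); plt.show()
#
# Each row sums to 1: a near-uniform row (~0.1 per face) means the class uses
# all 10 triangular faces of its pentachoron; a spiked row means a few faces
# dominate.
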
@torch.no_grad()
def margin_stats(cos_pre, targets):
    """
    Compute margin Δ = pos - best_neg from PRE-margin cosines.
    """
    pos = cos_pre.gather(1, targets.view(-1, 1)).squeeze(1)  # [B]
    masked = cos_pre.masked_fill(F.one_hot(targets, cos_pre.size(1)).bool(), -1e9)
    neg = masked.max(dim=1).values  # [B]
    delta = pos - neg
    return float(delta.mean()), float(delta.std())

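# Illustrative sanity check for margin_stats (synthetic data; a hypothetical
# helper, not called by the analysis below). With a single sample whose target
# cosine is 0.9 and whose best non-target cosine is 0.5, Δ should be 0.4.
# (Δ std is undefined for a single sample, so it is ignored here.)
def _sanity_check_margin_stats():
    cos = torch.tensor([[0.9, 0.5, 0.1]])
    tgt = torch.tensor([0])
    mu, _ = margin_stats(cos, tgt)
    assert abs(mu - 0.4) < 1e-6
    return mu
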
# ============================================
# FEATURE EXTRACTION AND ANALYSIS (upgraded)
# ============================================

class FeatureAnalyzer:
    """
    Analyze feature capacity and geometric patterns.
    Now aware of RoseFace:
      - can run with or without margin at inference (margin_mode)
      - stores pre-margin cosines and post-margin cosines
    """
    def __init__(self, model, dataloader, device=None, margin_mode='none'):
        """
        margin_mode:
          'none'  -> don't pass targets to model (no margin at eval)
          'apply' -> pass targets (apply margin at eval)
          'both'  -> run both (twice); margin results get a *_margin suffix
        """
        self.model = model
        self.dataloader = dataloader
        self.device = device or next(model.parameters()).device
        self.model.eval()
        assert margin_mode in ('none', 'apply', 'both')
        self.margin_mode = margin_mode

    def _forward_once(self, images, labels, apply_margin):
        # forward; return dict of tensors on CPU
        if apply_margin:
            outputs = self.model(images, return_features=True, targets=labels)
        else:
            outputs = self.model(images, return_features=True)

        out = {k: v.detach().cpu() for k, v in outputs.items() if isinstance(v, torch.Tensor)}
        # Derive post-margin cosines (if RoseFace): cos_post = logits / s
        if getattr(self.model, 'head_type', 'legacy') == 'roseface':
            s = float(getattr(self.model, 'scale_s', 1.0))
            if s > 0 and 'logits' in out:
                out['cos_post'] = (out['logits'] / s)
        return out

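    # Note: the RoseFace head is assumed to emit logits = s * cos(θ [+ margin]),
    # so dividing by scale_s recovers the post-margin cosines exactly; for a
    # legacy head no cos_post is derived and downstream code must tolerate its
    # absence (see the empty-tensor fallbacks in extract_all_features).
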
    def extract_all_features(self, max_batches=None):
        """
        Extract features, pre-margin cosines, post-margin cosines (if available).
        Returns dict with keys:
          - cls_features
          - features_proj
          - similarities (pre-margin cos)
          - cos_post (post-margin cos; RoseFace only)
          - logits
          - labels
        If margin_mode == 'both', the plain keys hold the no-margin pass and
        *_margin variants hold the margin pass.
        """
        agg = {}

        def append_batch(prefix, out_tensors, labels):
            # accumulate per-batch tensors under prefixed keys
            for k, v in out_tensors.items():
                agg.setdefault(f'{prefix}{k}', []).append(v)
            agg.setdefault(f'{prefix}labels', []).append(labels.cpu())

        with torch.no_grad():
            for i, (images, labels) in enumerate(tqdm(self.dataloader, desc="Extracting features")):
                if max_batches is not None and i >= max_batches:
                    break
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                if self.margin_mode in ('none', 'both'):
                    out0 = self._forward_once(images, labels, apply_margin=False)
                    append_batch('', out0, labels)

                if self.margin_mode in ('apply', 'both'):
                    out1 = self._forward_once(images, labels, apply_margin=True)
                    append_batch('m_', out1, labels)

        # concatenate everything we collected
        for k, lst in agg.items():
            agg[k] = torch.cat(lst, dim=0)

        # helper: pick the plain key, falling back to its 'm_' variant
        def pick(key: str):
            return agg.get(key, agg.get(f"m_{key}", torch.empty(0)))

        # unify view for downstream code
        if self.margin_mode == 'both':
            features = {
                'cls_features': pick('features'),
                'features_proj': pick('features_proj'),
                'similarities': agg.get('similarities', torch.empty(0)),
                'cos_post': agg.get('cos_post', torch.empty(0)),
                'labels': agg.get('labels', torch.empty(0)),
                'similarities_margin': agg.get('m_similarities', torch.empty(0)),
                'cos_post_margin': agg.get('m_cos_post', torch.empty(0)),
                'logits': agg.get('logits', torch.empty(0)),
                'logits_margin': agg.get('m_logits', torch.empty(0)),
            }
        else:
            # works for BOTH margin_mode='none' and margin_mode='apply'
            features = {
                'cls_features': pick('features'),
                'features_proj': pick('features_proj'),
                'similarities': pick('similarities'),  # pre-margin cosines
                'cos_post': pick('cos_post'),          # post-margin cosines (RoseFace)
                'labels': pick('labels'),
                'logits': pick('logits'),
            }
        return features

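    # Returned shapes for reference (N = samples seen, C = classes,
    # D = projection dim, E = backbone embed dim; exact keys depend on what
    # the model's forward returns): cls_features [N, E], features_proj [N, D],
    # similarities / cos_post / logits [N, C], labels [N].
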
    def analyze_feature_collapse(self, features):
        print("\n=== FEATURE COLLAPSE ANALYSIS ===")
        cls_features = features['cls_features'].numpy()
        unique_patterns = self._count_unique_patterns(cls_features)
        print(f"Estimated unique patterns: {unique_patterns}/100 classes")

        feature_std = cls_features.std(axis=0).mean()
        print(f"Average feature std: {feature_std:.4f}")

        labels = features['labels'].numpy()
        sample_size = min(1000, len(labels))
        indices = np.random.choice(len(labels), sample_size, replace=False)
        silhouette = silhouette_score(cls_features[indices], labels[indices])
        print(f"Silhouette score: {silhouette:.3f}")

        # centroid proximity count
        class_features = {}
        for i in range(100):
            mask = labels == i
            if mask.sum() > 0:
                class_features[i] = cls_features[mask].mean(axis=0)
        if class_features:
            centroids = np.stack(list(class_features.values()))
            d = np.linalg.norm(centroids[:, None] - centroids[None, :], axis=2)
            thr = np.percentile(d[d > 0], 20)
            close_pairs = (d < thr) & (d > 0)
            classes_with_close_neighbors = close_pairs.sum(axis=1)
            print(f"Classes with very similar features: {(classes_with_close_neighbors > 2).sum()}/100")

        return {'unique_patterns': unique_patterns, 'feature_std': feature_std, 'silhouette': silhouette}

    def analyze_geometric_patterns(self, features):
        print("\n=== GEOMETRIC PATTERN ANALYSIS ===")
        sims = features['similarities']  # pre-margin cosines [N,C]
        print(f"Average max cosine: {sims.max(dim=1)[0].mean():.3f}")
        print(f"Average min cosine: {sims.min(dim=1)[0].mean():.3f}")
        print(f"Cosine std: {sims.std():.3f}")

        high_sim_threshold = sims.mean() + sims.std()
        high_sim_count = (sims > high_sim_threshold).sum(dim=1).float().mean()
        print(f"Avg classes above (mean+std): {high_sim_count:.1f}/100")

        labels = features['labels']
        correct = sims.gather(1, labels.view(-1, 1)).squeeze(1).mean().item()
        wrong = (sims.sum(dim=1) - sims.gather(1, labels.view(-1, 1)).squeeze(1)) / (sims.size(1) - 1)
        margin = (correct - wrong.mean().item())
        print(f"Avg cosine margin (correct - mean wrong): {margin:.3f}")

        # RoseFace-specific: if post-margin cosines are present, compare deltas
        if 'cos_post' in features and features['cos_post'].numel() > 0:
            cos_post = features['cos_post']
            # shift on target column
            pos_pre = sims.gather(1, labels.view(-1, 1))
            pos_post = cos_post.gather(1, labels.view(-1, 1))
            shift = (pos_post - pos_pre).mean().item()
            print(f"Avg target cosine shift (post - pre): {shift:.3f}")

        return {
            'max_cos': sims.max(dim=1)[0].mean().item(),
            'cos_std': sims.std().item(),
            'high_sim_classes': high_sim_count.item(),
            'discrimination_margin': margin
        }

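    # Interpretation note (assuming an additive-angular margin à la ArcFace):
    # the target cosine moves from cos(θ) to cos(θ + m), so the "post - pre"
    # shift printed above should be negative whenever the margin is applied;
    # a shift near zero suggests the margin path was not actually active.
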
    def test_linear_probe(self, features, num_epochs=50):
        print("\n=== LINEAR PROBE TEST ===")
        X = features['cls_features']
        y = features['labels']
        n_train = int(0.8 * len(y))
        X_train, y_train = X[:n_train], y[:n_train]
        X_test, y_test = X[n_train:], y[n_train:]

        probe = torch.nn.Linear(X_train.shape[1], 100).to(self.device)
        opt = torch.optim.Adam(probe.parameters(), lr=0.01)

        X_train = X_train.to(self.device); y_train = y_train.to(self.device)
        X_test = X_test.to(self.device); y_test = y_test.to(self.device)

        best = 0.0
        for epoch in range(num_epochs):
            logits = probe(X_train)
            loss = F.cross_entropy(logits, y_train)
            opt.zero_grad(); loss.backward(); opt.step()
            if epoch % 10 == 0:
                with torch.no_grad():
                    acc = (probe(X_test).argmax(dim=1) == y_test).float().mean().item()
                best = max(best, acc)
                print(f"  Epoch {epoch}: Test acc = {acc*100:.1f}%")
        with torch.no_grad():
            final = (probe(X_test).argmax(dim=1) == y_test).float().mean().item()
        best = max(best, final)
        print(f"Best linear probe accuracy: {best*100:.1f}%")
        return best

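    # Note: the 80/20 split in test_linear_probe above is positional, so it is
    # only meaningful when extraction order is not grouped by class (true for
    # the standard CIFAR-100 test loader; re-shuffle X/y first if your loader
    # iterates in class-sorted order).
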
    def visualize_features(self, features, method='tsne', n_samples=2000):
        print(f"\n=== FEATURE VISUALIZATION ({method.upper()}) ===")
        cls_features = features['cls_features'].numpy()
        labels = features['labels'].numpy()
        n_samples = min(n_samples, len(labels))
        idx = np.random.choice(len(labels), n_samples, replace=False)
        X = cls_features[idx]; y = labels[idx]

        print(f"Reducing {n_samples} samples to 2D...")
        reducer = TSNE(n_components=2, random_state=42, perplexity=30) if method == 'tsne' else PCA(n_components=2)
        X2 = reducer.fit_transform(X)

        plt.figure(figsize=(12, 9))
        scatter = plt.scatter(X2[:, 0], X2[:, 1], c=y, cmap='nipy_spectral', alpha=0.6, s=15)
        plt.title(f'Feature Space Visualization ({method.upper()})'); plt.xlabel('Comp 1'); plt.ylabel('Comp 2')

        print("Estimating visual clusters...")
        silhouette_scores, K = [], list(range(30, 60, 5))
        for k in K:
            kmeans = KMeans(n_clusters=k, random_state=42, n_init=3)
            cls_lbl = kmeans.fit_predict(X2)
            silhouette_scores.append(silhouette_score(X2, cls_lbl))
        best_k = K[int(np.argmax(silhouette_scores))]
        kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=5)
        cluster_labels = kmeans.fit_predict(X2)
        n_populated = len(np.unique(cluster_labels))
        plt.text(0.02, 0.98, f'Estimated clusters: {n_populated}', transform=plt.gca().transAxes,
                 va='top', fontsize=12, bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        cbar = plt.colorbar(scatter, ticks=np.arange(0, 100, 10)); cbar.set_label('Class', rotation=270, labelpad=15)
        plt.tight_layout(); plt.show()
        return X2, n_populated

    def analyze_pentachora_usage(self):
        print("\n=== PENTACHORA USAGE ANALYSIS ===")
        print(f"Classes: {self.model.num_classes}")
        print(f"Embed dim: {self.model.embed_dim} | Penta dim: {self.model.pentachora_dim}")
        print(f"Head: {getattr(self.model, 'head_type', 'legacy')} | Prototype: {getattr(self.model, 'prototype_mode', 'n/a')} | Margin: {getattr(self.model, 'margin_type', 'n/a')}")
        if hasattr(self.model, 'to_pentachora_dim'):
            if isinstance(self.model.to_pentachora_dim, torch.nn.Linear):
                print(f"Projection: Linear {self.model.embed_dim}→{self.model.pentachora_dim}")
            else:
                print("Projection: Identity")

        # Inter-class centroid similarity (legacy view)
        centroids = self.model.get_class_centroids()
        sim = centroids @ centroids.t()
        mask = ~torch.eye(self.model.num_classes, dtype=torch.bool, device=sim.device)
        off = sim[mask]
        print(f"\nCentroid sims: mean={off.mean():.3f} max={off.max():.3f} min={off.min():.3f}")

        # Principal-angle overlap
        mean_fro, std_fro = principal_angle_overlap(self.model.class_pentachora)
        print(f"Principal-angle Fro overlap: mean={mean_fro:.3f} ± {std_fro:.3f} (lower is better)")

        return {'mean_similarity': off.mean().item(), 'max_similarity': off.max().item(), 'mean_fro_overlap': mean_fro}

    def run_full_analysis(self):
        print("="*60); print("COMPREHENSIVE FEATURE ANALYSIS"); print("="*60)
        print("\nExtracting features (margin_mode =", self.margin_mode, ") ...")
        feats = self.extract_all_features(max_batches=50)
        print(f"✓ Extracted features from {len(feats['labels'])} samples")

        res = {}
        res['collapse'] = self.analyze_feature_collapse(feats)
        res['geometric'] = self.analyze_geometric_patterns(feats)

        # Margin stats from PRE-margin cosines
        mu, sig = margin_stats(feats['similarities'], feats['labels'])
        print(f"PRE-margin Δ (pos - best neg): mean={mu:.3f}, std={sig:.3f}")

        # Face-usage heatmap
        if 'features_proj' in feats and feats['features_proj'].numel() > 0:
            heat = face_usage_heatmap(self.model, feats['features_proj'].to(self.device),
                                      feats['labels'].to(self.device),
                                      norm_type=getattr(self.model, 'norm_type', 'l1'))
            print("Face-usage heatmap computed [C,10] (display 3 most face-concentrated classes):")
            # rows are normalized to sum to 1, so rank classes by how
            # concentrated their face usage is rather than by row mass
            class_conc = heat.max(dim=1).values
            top3 = torch.topk(class_conc, k=min(3, heat.size(0))).indices.tolist()
            for c in top3:
                print(f"  class {c}: {heat[c].cpu().numpy().round(3)}")

        res['linear_probe'] = self.test_linear_probe(feats)
        res['pentachora'] = self.analyze_pentachora_usage()

        # Visualizations
        _, n_tsne = self.visualize_features(feats, 'tsne')
        _, n_pca = self.visualize_features(feats, 'pca')
        res['visual_clusters'] = {'tsne': n_tsne, 'pca': n_pca}

        # Summary
        print("\n" + "="*60); print("DIAGNOSIS SUMMARY"); print("="*60)
        up = res['collapse']['unique_patterns']; lp = res['linear_probe']
        if up <= 45 and lp <= 0.42:
            print(f"🔴 Compact regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        elif up > 60:
            print(f"✅ Diverse regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        else:
            print(f"⚡ Partial bottleneck: {up} unique patterns; linear probe {lp*100:.1f}%")

        return res

    # ------------------------------
    # Helpers (unchanged interface)
    # ------------------------------
    def _count_unique_patterns(self, features, method='elbow'):
        X = features[:min(3000, len(features))]
        if method == 'elbow':
            inertias, K = [], list(range(20, 80, 5))
            for k in K:
                km = KMeans(n_clusters=k, random_state=42, n_init=3)
                km.fit(X); inertias.append(km.inertia_)
            diffs = np.diff(inertias); diffs2 = np.diff(diffs)
            if len(diffs2) > 0:
                elbow_idx = int(np.argmax(np.abs(diffs2))) + 1
                est = K[elbow_idx]
            else:
                est = 41
        else:
            scores, K = [], list(range(30, 60, 5))
            for k in K:
                km = KMeans(n_clusters=k, random_state=42, n_init=3)
                lbl = km.fit_predict(X)
                scores.append(silhouette_score(X, lbl))
            est = K[int(np.argmax(scores))]
        return est

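# End-to-end usage sketch for FeatureAnalyzer (assumes `model` and a test
# dataloader are already in scope, as in the __main__ block below):
#
#   analyzer = FeatureAnalyzer(model, test_loader, margin_mode='both')
#   feats = analyzer.extract_all_features(max_batches=10)
#   print(feats['similarities'].shape, feats['similarities_margin'].shape)
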
# ============================================
# QUICK TEST (RoseFace-aware)
# ============================================

def quick_41_percent_test(model, test_loader, device=None, apply_margin_eval=False):
    """
    If apply_margin_eval=True, pass targets to model at eval (margin applied).
    Otherwise, evaluate without margin (classic).
    """
    print("="*60); print("41% ACCURACY CAP HYPOTHESIS TEST"); print("="*60)
    model.eval()
    device = device or next(model.parameters()).device

    # 1) Accuracy
    print("\n1. Verifying model accuracy...")
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images, targets=labels) if apply_margin_eval else model(images)
            pred = outputs['logits'].argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    acc = 100 * correct / total
    policy = "WITH margin" if apply_margin_eval else "NO margin"
    print(f"  Test Accuracy ({policy}): {acc:.2f}%")

    is_at_cap = abs(acc - 41.0) < 3.0
    # 2) Focused analysis (small sample)
    print("\n2. Analyzing feature patterns...")
    margin_mode = 'apply' if apply_margin_eval else 'none'
    analyzer = FeatureAnalyzer(model, test_loader, device=device, margin_mode=margin_mode)
    feats = analyzer.extract_all_features(max_batches=20)
    if hasattr(model, 'rose_face_weights'):  # offline eval needs the RoseFace head
        acc_rose5 = offline_head_eval_rose5(model, feats['features_proj'].to(device), feats['labels'])
        print(f"Offline prototype eval (rose5, no margin): {acc_rose5*100:.2f}%")
    collapse = analyzer.analyze_feature_collapse(feats)
    pent = analyzer.analyze_pentachora_usage()

    print("\n" + "="*60); print("VERDICT"); print("="*60)
    if is_at_cap and collapse['unique_patterns'] <= 45:
        print("🔴 41% CAP CONFIRMED")
        print(f"  Acc: {acc:.1f}% | Unique patterns: {collapse['unique_patterns']}")
        print("  Likely geometric bottleneck.")
    elif collapse['unique_patterns'] <= 45:
        print("⚠️ FEATURE BOTTLENECK DETECTED")
        print(f"  {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")
    else:
        print("✅ NO 41% BOTTLENECK")
        print(f"  {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")

    return {
        'accuracy': acc,
        'unique_patterns': collapse['unique_patterns'],
        'is_bottlenecked': collapse['unique_patterns'] <= 45,
        'pentachora_similarity': pent['mean_similarity']
    }

@torch.no_grad()
def offline_head_eval_rose5(model, features_proj, labels):
    # compute z_l2 (dual-norm bridge)
    z = features_proj
    z_l2 = z / (z.norm(p=2, dim=-1, keepdim=True) + 1e-12)
    # build rose5 prototypes [C,D]
    V = torch.stack([p.vertices for p in model.class_pentachora], dim=0).to(z.device, z.dtype)  # [C,5,D]
    V = V / (V.norm(p=2, dim=-1, keepdim=True) + 1e-12)
    W = model.rose_face_weights.to(z.device, z.dtype)  # [10,5]
    faces = torch.einsum('tf,cfd->ctd', W, V)          # [C,10,D]
    proto = (V.mean(dim=1) + 0.5 * faces.mean(dim=1))
    proto = proto / (proto.norm(p=2, dim=-1, keepdim=True) + 1e-12)  # [C,D]
    cos = z_l2 @ proto.t()  # [N,C]
    acc = (cos.argmax(dim=1) == labels.to(z.device)).float().mean().item()
    return acc

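# Shape walk-through for the einsum above (given the conventions in this file):
# W is [10, 5] (one row of vertex weights per triangular face) and V is
# [C, 5, D], so 'tf,cfd->ctd' mixes the 5 vertices into faces [C, 10, D].
# The prototype is then the vertex mean plus half the face mean, L2-normalized.
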
# ==========================
# RUN ANALYSIS (example)
# ==========================

if __name__ == "__main__":
    print("Starting RoseFace-aware feature analysis...")
    print(f"Model device: {next(model.parameters()).device}")
    # Dataloaders
    train_loader, test_loader, train_transforms = get_cifar100_dataloaders(batch_size=128)

    # Quick test in BOTH modes (optional): compare accuracy
    print("\nRunning quick 41% hypothesis test (NO margin at eval)...")
    res_nom = quick_41_percent_test(model, test_loader, apply_margin_eval=False)

    print("\nRunning quick 41% hypothesis test (WITH margin at eval)...")
    res_mar = quick_41_percent_test(model, test_loader, apply_margin_eval=True)

    # Full analysis with richer diagnostics (no margin at eval is typical)
    analyzer = FeatureAnalyzer(model, test_loader, margin_mode='none')
    full_results = analyzer.run_full_analysis()