# geolip-constellation-core / analyze_weights.py
# Last commit: 46f26dc "Update analyze_weights.py" (author: AbstractPhil)
#!/usr/bin/env python3
"""
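GeoLIP Core — Full Analysis + Sphere Visualizations
===================================================
Auto-detects CIFAR-10 vs CIFAR-100 from the checkpoint config (`num_classes`),
runs numeric, embedding, constellation and gradient audits, writes eight
figures to OUT_DIR, and optionally uploads them to the Hugging Face Hub.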
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
from collections import defaultdict
from torchvision import datasets, transforms
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CKPT = "checkpoints/geolip_core_best.pt"  # checkpoint to analyze
OUT_DIR = "analysis_out"                  # every figure is written here
BATCH = 256                               # evaluation batch size
# ── HuggingFace push ──
HF_REPO_ID = "AbstractPhil/geolip-constellation-core"
HF_PUSH = True  # when True, OUT_DIR is uploaded to HF_REPO_ID under analysis/ at the end
# Per-channel normalization applied to the validation images (CIFAR-10 and CIFAR-100 alike).
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
os.makedirs(OUT_DIR, exist_ok=True)
print("=" * 70)
print("GEOLIP CORE — ANALYSIS + SPHERE VISUALIZATIONS")
print(f" Checkpoint: {CKPT}")
print(f" Output: {OUT_DIR}/")
print("=" * 70)
# ══════════════════════════════════════════════════════════════════
# LOAD — auto-detect dataset from config
# ══════════════════════════════════════════════════════════════════
# weights_only=False unpickles arbitrary Python objects: only point CKPT at trusted files.
ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
cfg = ckpt["config"]
N_CLASSES = cfg.get('num_classes', 10)
print(f" Epoch: {ckpt['epoch']} Val acc: {ckpt['val_acc']:.1f}%")
print(f" Config: output_dim={cfg.get('output_dim')}, "
      f"n_anchors={cfg.get('n_anchors')}, "
      f"n_comp={cfg.get('n_comp')}, d_comp={cfg.get('d_comp')}, "
      f"num_classes={N_CLASSES}")
# Up to 10 classes -> CIFAR-10 test split, otherwise CIFAR-100 test split.
if N_CLASSES <= 10:
    ds_cls = datasets.CIFAR10
    ds_name = "CIFAR-10"
else:
    ds_cls = datasets.CIFAR100
    ds_name = "CIFAR-100"
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_ds = ds_cls(root='./data', train=False, download=True, transform=val_transform)
# Read CIFAR-100 names from the validation set itself instead of constructing a
# second, untransformed CIFAR-100 dataset just for its `.classes`.
if N_CLASSES <= 10:
    CLASS_NAMES = CIFAR10_CLASSES[:N_CLASSES]
else:
    CLASS_NAMES = val_ds.classes
print(f" Dataset: {ds_name} ({N_CLASSES} classes)")
# NOTE(review): GeoLIPCore is neither defined nor imported in this file; it must
# already be in scope (e.g. an earlier notebook cell or the training module).
# TODO confirm and add the explicit import.
model = GeoLIPCore(**cfg).to(DEVICE)
model.load_state_dict(ckpt["state_dict"])
model.eval()
# Pinned host memory only speeds up host->GPU copies; without CUDA it just warns.
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False, num_workers=2,
    pin_memory=(DEVICE == "cuda"))
total_params = sum(p.numel() for p in model.parameters())
# ══════════════════════════════════════════════════════════════════
# COLLECT ALL EMBEDDINGS + PREDICTIONS
# ══════════════════════════════════════════════════════════════════
print("\n Collecting embeddings...")
# Per-batch CPU tensors for each model output, concatenated once the pass is done.
_FIELDS = ('embs', 'tris', 'nearest', 'labels', 'preds', 'logits')
_batches = {field: [] for field in _FIELDS}
with torch.no_grad():
    for batch_imgs, batch_lbls in val_loader:
        result = model(batch_imgs.to(DEVICE))
        batch_logits = result['logits']
        _batches['embs'].append(result['embedding'].float().cpu())
        _batches['tris'].append(result['triangulation'].float().cpu())
        _batches['nearest'].append(result['nearest'].cpu())
        _batches['labels'].append(batch_lbls)
        _batches['preds'].append(batch_logits.argmax(-1).cpu())
        _batches['logits'].append(batch_logits.float().cpu())
embs, tris, nearest, labels, preds, logits = [torch.cat(_batches[field]) for field in _FIELDS]
# Row-wise L2-normalized copy of the embeddings (unit vectors).
embs_n = F.normalize(embs, dim=-1)
val_acc = (preds == labels).float().mean().item() * 100
print(f" Val accuracy: {val_acc:.1f}%")
print(f" Embeddings: {embs.shape}")
# ══════════════════════════════════════════════════════════════════
# ANCHOR PUSH — drag anchors to where the data lives
# ══════════════════════════════════════════════════════════════════
N_PUSH_STEPS = 30  # number of push iterations
PUSH_LR = 0.5      # fraction of the way each anchor moves toward its target per step
print(f"\n Pushing anchors toward CLASS centroids ({N_PUSH_STEPS} steps, lr={PUSH_LR})...")
# Snapshot the anchors before pushing so drift and cosine gain can be measured later.
anchors_before = model.constellation.anchors.detach().float().cpu().clone()
anch_n_before = F.normalize(anchors_before, dim=-1)
# Average over samples of the cosine to each sample's most similar anchor.
cos_before = torch.amax(embs_n @ anch_n_before.T, dim=1).mean().item()
print(f" Before: mean nearest_cos = {cos_before:.4f}")
# Device copies used by the push below.
emb_device, lbl_device = embs.to(DEVICE), labels.to(DEVICE)
def _greedy_class_assignment(cos_ac, cap):
    """Greedily map each anchor to a class centroid, most similar pairs first.

    Args:
        cos_ac: (A, C) cosine similarities between normalized anchors and
            class centroids (any device).
        cap: maximum number of anchors a single class may receive.

    Returns:
        A list of A class indices. (anchor, class) pairs are visited in
        descending cosine order; a pair is taken when its anchor is still free
        and its class is below ``cap``. Any anchor left over falls back to its
        most similar centroid.
    """
    cos_cpu = cos_ac.detach().float().cpu()
    n_a, n_c = cos_cpu.shape
    # Work on host-side Python ints: iterating a device tensor element by
    # element would force a GPU sync for every comparison / .item().
    order = cos_cpu.flatten().argsort(descending=True).tolist()
    assigned = [-1] * n_a
    per_class = [0] * n_c
    remaining = n_a
    for flat in order:
        a, c = divmod(flat, n_c)
        if assigned[a] >= 0 or per_class[c] >= cap:
            continue
        assigned[a] = c
        per_class[c] += 1
        remaining -= 1
        if remaining == 0:
            break
    if remaining:
        fallback = cos_cpu.argmax(dim=1).tolist()
        assigned = [c if c >= 0 else fallback[a] for a, c in enumerate(assigned)]
    return assigned


if hasattr(model, 'push_anchors_to_centroids'):
    # The model implements its own centroid push: drive it and report progress.
    for step in range(N_PUSH_STEPS):
        moved = model.push_anchors_to_centroids(emb_device, lbl_device, lr=PUSH_LR)
        if (step + 1) % 10 == 0:
            an_tmp = F.normalize(model.constellation.anchors.detach().float().cpu(), dim=-1)
            c_tmp = (embs_n @ an_tmp.T).max(dim=1).values.mean().item()
            print(f" Step {step+1:3d}: nearest_cos = {c_tmp:.4f}, moved = {moved}")
else:
    # Inline class-centroid push. Each step:
    # - assigns anchors to classes, at most floor(A/C)+1 per class;
    # - moves every anchor PUSH_LR of the way toward its class centroid;
    # - re-projects the anchors onto the unit sphere.
    with torch.no_grad():
        anchors_param = model.constellation.anchors.data
        emb_dev = F.normalize(emb_device, dim=-1)
        # Unit-norm class centroids, computed once: (C, D)
        classes = lbl_device.unique()
        n_cls = classes.shape[0]
        centroids = torch.cat(
            [F.normalize(emb_dev[lbl_device == c].mean(0, keepdim=True), dim=-1)
             for c in classes], dim=0)
        n_a = anchors_param.shape[0]
        anchors_per_class = n_a // n_cls
        for step in range(N_PUSH_STEPS):
            an = F.normalize(anchors_param, dim=-1)
            assigned = _greedy_class_assignment(an @ centroids.T, anchors_per_class + 1)
            # rank = number of earlier anchors assigned to the same class. Every
            # anchor after the first of its class gets a jittered target,
            # presumably so anchors sharing a centroid do not collapse together.
            seen = defaultdict(int)
            ranks = []
            for cls_idx in assigned:
                ranks.append(seen[cls_idx])
                seen[cls_idx] += 1
            # Advanced indexing copies, so editing `targets` leaves `centroids` intact.
            targets = centroids[torch.tensor(assigned, device=centroids.device)]
            jitter = torch.tensor([r > 0 for r in ranks], device=centroids.device)
            if jitter.any():
                base = targets[jitter]
                noise = torch.randn_like(base) * 0.05
                # Keep only the component orthogonal to the (unit) target, so the
                # jitter moves the target along the sphere.
                noise = noise - (noise * base).sum(dim=-1, keepdim=True) * base
                targets[jitter] = F.normalize(base + noise, dim=-1)
            anchors_param.copy_(F.normalize(an + PUSH_LR * (targets - an), dim=-1))
            if (step + 1) % 10 == 0:
                an_tmp = F.normalize(anchors_param, dim=-1)
                c_tmp = (emb_dev @ an_tmp.T).max(dim=1).values.mean().item()
                print(f" Step {step+1:3d}: nearest_cos = {c_tmp:.4f}")
# After stats
anchors = model.constellation.anchors.detach().float().cpu()
anchors_n = F.normalize(anchors, dim=-1)
n_anchors = anchors.shape[0]
cos_after = (embs_n @ anchors_n.T).max(dim=1).values.mean().item()
# Mean Euclidean distance each unit-normalized anchor moved during the push.
drift = (anch_n_before - anchors_n).norm(dim=-1).mean().item()
print(f" After: mean nearest_cos = {cos_after:.4f} (Δ={cos_after - cos_before:+.4f})")
print(f" Anchor drift: {drift:.4f}")
# Re-triangulate with the pushed anchors: from here on `tris` is 1 - cos and
# `nearest` is the argmax cosine against the pushed anchors, replacing the model's
# own outputs. NOTE(review): assumes the model's triangulation is also 1 - cos — confirm.
with torch.no_grad():
    new_cos = embs_n @ anchors_n.T
    tris = 1.0 - new_cos
    nearest = new_cos.argmax(dim=1)
print(f" Anchors: {anchors.shape}")
# ══════════════════════════════════════════════════════════════════
# AUDIT 1: NUMERIC HEALTH
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*70}")
print("AUDIT 1: NUMERIC HEALTH")
print(f"{'='*70}")
# Flag parameters that contain NaN/inf or whose values have (near-)zero spread.
issues = []
for name, param in model.named_parameters():
    p = param.detach().float()
    n_nan = torch.isnan(p).sum().item()
    n_inf = torch.isinf(p).sum().item()
    # std of a single element is undefined, so scalars are never flagged as collapsed.
    p_std = p.std().item() if p.numel() > 1 else 0
    flags = []
    if n_nan > 0:
        flags.append(f"NaN={n_nan}")
    if n_inf > 0:
        flags.append(f"inf={n_inf}")
    if p_std < 1e-8 and p.numel() > 1:
        flags.append(f"COLLAPSED(std={p_std:.2e})")
    if flags:
        print(f" ⚠ {name:<50} {' '.join(flags)}")
        issues.append(name)
if not issues:
    print(f" ✓ All {total_params:,} parameters clean")
# ══════════════════════════════════════════════════════════════════
# AUDIT 2: PER-CLASS ACCURACY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*70}")
print("AUDIT 2: PER-CLASS ACCURACY")
print(f"{'='*70}")


def _class_accuracy(cls_id):
    """Top-1 accuracy (%) on validation samples of class `cls_id`; 0 if it has none."""
    members = labels == cls_id
    if members.sum() <= 0:
        return 0
    return (preds[members] == cls_id).float().mean().item() * 100


class_accs = [_class_accuracy(c) for c in range(N_CLASSES)]
if N_CLASSES > 10:
    # Too many classes to list: show the 10 worst and the 10 best.
    sorted_idx = sorted(range(N_CLASSES), key=lambda c: class_accs[c])
    for heading, chosen in ((" Bottom 10:", sorted_idx[:10]),
                            (" Top 10:", sorted_idx[-10:])):
        print(heading)
        for c in chosen:
            print(f" {CLASS_NAMES[c]:<20}: {class_accs[c]:5.1f}%")
else:
    for c, acc in enumerate(class_accs):
        print(f" {CLASS_NAMES[c]:<12}: {acc:5.1f}%")
print(f" Mean: {np.mean(class_accs):.1f}% "
      f"Median: {np.median(class_accs):.1f}% "
      f"Std: {np.std(class_accs):.1f}%")
# ══════════════════════════════════════════════════════════════════
# AUDIT 3: EMBEDDING SPACE
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*70}")
print("AUDIT 3: EMBEDDING SPACE")
print(f"{'='*70}")
# Pairwise cosine statistics over (at most) the first 2000 normalized embeddings.
n_sample = min(2000, len(embs))
head_n = embs_n[:n_sample]
sim = head_n @ head_n.T
sim_mask = torch.eye(n_sample, dtype=torch.bool).logical_not()  # off-diagonal pairs
labels_s = labels[:n_sample]
same_class = labels_s[None, :] == labels_s[:, None]
same_not_self = same_class & sim_mask
diff_class = sim_mask & ~same_class
# NOTE: reported as "Self-similarity", but this is the mean cosine over ALL
# off-diagonal pairs (self pairs are excluded).
self_sim = sim[sim_mask].mean().item()
same_cos = sim[same_not_self].mean().item() if same_not_self.any() else 0
diff_cos = sim[diff_class].mean().item() if diff_class.any() else 0
gap = same_cos - diff_cos
# Effective dimension = 1 / sum(p_i^2), where p is the singular-value spectrum of
# the first 512 normalized embeddings rescaled to sum to 1.
spectrum = torch.linalg.svd(embs_n[:512].float(), full_matrices=False).S
spectrum = spectrum / spectrum.sum()
eff_dim = spectrum.pow(2).sum().reciprocal().item()
print(f" Self-similarity: {self_sim:.4f}")
print(f" Same-class cos: {same_cos:.4f}")
print(f" Diff-class cos: {diff_cos:.4f}")
print(f" Gap: {gap:.4f}")
print(f" Effective dim: {eff_dim:.1f}/{embs.shape[1]}")
# ══════════════════════════════════════════════════════════════════
# AUDIT 4: CONSTELLATION HEALTH
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*70}")
print("AUDIT 4: CONSTELLATION HEALTH")
print(f"{'='*70}")
# Pairwise cosine between normalized anchors, diagonal excluded.
anch_sim = anchors_n @ anchors_n.T
anch_mask = ~torch.eye(n_anchors, dtype=torch.bool)
anch_off = anch_sim[anch_mask]
n_active = nearest.unique().numel()
# counts[a] = number of samples whose nearest anchor is a (0 for unused anchors).
counts = torch.bincount(nearest, minlength=n_anchors)
print(f" Anchors: {n_anchors} × {anchors.shape[1]}")
if anch_off.numel() > 0:
    print(f" Pairwise cos: mean={anch_off.mean():.4f} max={anch_off.max():.4f}")
else:
    # A single anchor has no pairs; .max() on an empty tensor would raise.
    print(" Pairwise cos: n/a (single anchor)")
print(f" Active: {n_active}/{n_anchors}")
print(f" Utilization: min={counts.min().item()} max={counts.max().item()} "
      f"mean={counts.float().mean():.1f} std={counts.float().std():.1f}")
# ══════════════════════════════════════════════════════════════════
# AUDIT 5: PENTACHORON CV
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*70}")
print("AUDIT 5: PENTACHORON CV")
print(f"{'='*70}")
# Coefficient of variation of 4-simplex (pentachoron) volumes spanned by random
# 5-point subsets of the first 2000 normalized embeddings. Volumes come from the
# Cayley-Menger determinant: V^2 = -det(CM) / (2^4 * (4!)^2) = -det(CM) / 9216.
N_PENTACHORA = 500
sample = embs_n[:2000].to(DEVICE).float()
n_pool = sample.shape[0]
with torch.no_grad():
    if n_pool >= 5:
        # All subsets in one batch: per row, the first 5 indices of a random permutation.
        idx = torch.rand(N_PENTACHORA, n_pool, device=DEVICE).argsort(dim=1)[:, :5]
        pts = sample[idx]                                   # (N, 5, D)
        gram = pts @ pts.transpose(1, 2)                    # (N, 5, 5)
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Squared pairwise distances; relu clips tiny negative round-off.
        d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
        cm = torch.zeros(N_PENTACHORA, 6, 6, device=DEVICE, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        v2 = -torch.linalg.det(cm) / 9216
        # Keep only non-degenerate simplices (numerically positive squared volume).
        vols = v2[v2 > 1e-20].sqrt().cpu()
    else:
        # Fewer than 5 points cannot span a 4-simplex.
        vols = torch.empty(0)
if len(vols) > 10:
    vt = vols
    v_cv = (vt.std() / (vt.mean() + 1e-8)).item()
    band = "✓ IN BAND" if 0.18 <= v_cv <= 0.25 else "✗ outside"
    print(f" CV: {v_cv:.4f} ({band})")
    print(f" Vol mean: {vt.mean():.6f} std: {vt.std():.6f}")
else:
    v_cv = 0
    print(f" ⚠ Not enough valid pentachora ({len(vols)})")
# ══════════════════════════════════════════════════════════════════
# AUDIT 6: CONFIDENCE CALIBRATION
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*70}")
print("AUDIT 6: CONFIDENCE CALIBRATION")
print(f"{'='*70}")
# Confidence = max softmax probability; summarized separately for right and wrong predictions.
probs = logits.softmax(-1)
conf = probs.amax(dim=1)
correct_mask = preds == labels
wrong_mask = ~correct_mask
hit_conf = conf[correct_mask]
print(f" Correct: mean_conf={hit_conf.mean():.4f} "
      f"std={hit_conf.std():.4f}")
if wrong_mask.any():
    wrong_conf = conf[wrong_mask]
    n_wrong = wrong_conf.numel()
    overconf = (wrong_conf > 0.9).sum().item()  # confidently wrong predictions
    print(f" Wrong: mean_conf={wrong_conf.mean():.4f} "
          f"std={wrong_conf.std():.4f}")
    print(f" Overconfident wrong (>0.9): {overconf}/{n_wrong} "
          f"({100*overconf/max(n_wrong,1):.1f}%)")
# ══════════════════════════════════════════════════════════════════
# AUDIT 7: GRADIENT FLOW
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*70}")
print("AUDIT 7: GRADIENT FLOW")
print(f"{'='*70}")
# One forward/backward pass on 16 validation images, then per-module gradient norms.
# NOTE(review): train() mode means any BatchNorm-style running statistics are
# updated by this pass; model outputs are not recomputed afterwards in this script.
model.train()
model.zero_grad()
imgs_g, lbls_g = next(iter(val_loader))
imgs_g = imgs_g[:16].to(DEVICE)
lbls_g = lbls_g[:16].to(DEVICE)
# Use bf16 autocast on CUDA only. A "cuda" autocast on a CPU-only machine is
# disabled with a warning anyway, so pass the real device and turn it off explicitly.
with torch.amp.autocast(DEVICE, dtype=torch.bfloat16, enabled=(DEVICE == "cuda")):
    out = model(imgs_g)
    # The small embedding term presumably gives parameters that feed only the
    # embedding (not the logits) a gradient path.
    loss = F.cross_entropy(out['logits'], lbls_g) + 0.1 * out['embedding'].mean()
loss.backward()
grad_by_mod = defaultdict(list)
no_grad_params = []
_MODULE_KEYS = ("encoder", "constellation", "patchwork", "classifier")
for name, param in model.named_parameters():
    if param.grad is None:
        # Frozen parameters legitimately have no gradient; trainable ones should.
        if param.requires_grad:
            no_grad_params.append(name)
        continue
    gn = param.grad.detach().float().norm().item()
    # First matching keyword wins (same precedence as an if/elif chain).
    mod = next((key for key in _MODULE_KEYS if key in name), "other")
    grad_by_mod[mod].append(gn)
for mod in sorted(grad_by_mod):
    norms = grad_by_mod[mod]
    print(f" {mod:<15}: mean={np.mean(norms):.6f} max={np.max(norms):.6f} "
          f"({len(norms)} params)")
# Only report full coverage when it was actually observed.
if no_grad_params:
    print(f" ⚠ {len(no_grad_params)} trainable parameters received no gradient:")
    for name in no_grad_params:
        print(f"   {name}")
else:
    print(" ✓ All parameters receive gradient")
model.zero_grad()  # release gradient buffers; nothing below reads them
model.eval()
# ══════════════════════════════════════════════════════════════════
# VISUALIZATIONS
# ══════════════════════════════════════════════════════════════════
# Plotting is optional: without matplotlib the audits still run and every
# plotting section is skipped.
try:
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend: figures are only saved to disk
    import matplotlib.pyplot as plt
except ImportError:
    HAS_PLT = False
    print("\n ⚠ matplotlib not available, skipping visualizations")
else:
    HAS_PLT = True
if HAS_PLT:
    if N_CLASSES > 10:
        # Many classes: rotate the hue by the golden ratio for well-separated hues.
        # Saturation cycles 0.75/0.875/1.0 and value alternates 1.0/0.85, so
        # neighbouring indices also differ in intensity.
        import colorsys

        def _hsv_to_hex(hue, sat, val):
            r, g, b = colorsys.hsv_to_rgb(hue, sat, val)
            return f'#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}'

        CLASS_COLORS = [
            _hsv_to_hex((i * 0.618033988749895) % 1.0,
                        0.75 + 0.25 * (i % 3) / 2,
                        0.85 + 0.15 * ((i + 1) % 2))
            for i in range(N_CLASSES)
        ]
    else:
        CLASS_COLORS = [
            '#e6194b', '#3cb44b', '#4363d8', '#f58231', '#911eb4',
            '#42d4f4', '#f032e6', '#bfef45', '#469990', '#dcbeff']
    # Dark theme shared by every figure.
    plt.style.use('dark_background')
    plt.rcParams.update({
        'figure.facecolor': '#1a1a2e',
        'axes.facecolor': '#16213e',
        'axes.edgecolor': '#444466',
        'axes.labelcolor': '#e0e0e0',
        'text.color': '#e0e0e0',
        'xtick.color': '#aaaacc',
        'ytick.color': '#aaaacc',
        'grid.color': '#333355',
        'legend.facecolor': '#1a1a2e',
        'legend.edgecolor': '#444466',
    })
    print(f"\n{'='*70}")
    print("VISUALIZATIONS")
    print(f"{'='*70}")
def save_fig(filename, dpi=200):
    """Write the current matplotlib figure to ``OUT_DIR/<filename>`` at `dpi` resolution."""
    target = "/".join((OUT_DIR, filename))
    plt.savefig(target, dpi=dpi)
# ── Sphere grid helpers ──
def draw_sphere_grid_2d(ax, radius, n_meridians=24):
    """Overlay a circular reference grid of the given radius on a 2-D axes.

    Draws, in order:
    - a wide cyan halo under a solid white outer ring;
    - dashed inner rings at 50% and 75% of the radius;
    - `n_meridians` radial tick marks across the outer ring;
    - faint crosshairs and an ``r=<radius>`` label.
    """
    print(f" >>> DRAWING 2D GRID: radius={radius:.4f}, lw=5, white+cyan")
    angles = np.linspace(0, 2 * np.pi, 500)
    ring_x = radius * np.cos(angles)
    ring_y = radius * np.sin(angles)

    # Outer ring: translucent cyan glow underneath, crisp white line on top.
    ax.plot(ring_x, ring_y, color='#00e5ff', alpha=0.6, lw=9, zorder=49)
    ax.plot(ring_x, ring_y, color='white', alpha=1.0, lw=5, zorder=50,
            solid_capstyle='round')

    # Dashed inner rings at fixed fractions of the radius.
    for scale in (0.5, 0.75):
        ax.plot(scale * ring_x, scale * ring_y,
                color='#00e5ff', alpha=0.5, lw=2, linestyle='--', zorder=50)

    # Radial ticks straddling the outer ring (92% -> 108% of the radius).
    inner, outer = radius * 0.92, radius * 1.08
    for k in range(n_meridians):
        theta_k = 2 * np.pi * k / n_meridians
        cos_k, sin_k = np.cos(theta_k), np.sin(theta_k)
        ax.plot([inner * cos_k, outer * cos_k],
                [inner * sin_k, outer * sin_k],
                color='white', alpha=0.8, lw=2, zorder=50)

    # Horizontal then vertical crosshair through the origin.
    half = radius * 1.15
    for xs, ys in (([-half, half], [0, 0]), ([0, 0], [-half, half])):
        ax.plot(xs, ys, color='#00e5ff', alpha=0.3, lw=1.5, zorder=49)

    # Radius label in the upper-right quadrant.
    ax.text(radius * 0.72, radius * 0.72, f'r={radius:.2f}',
            color='#00e5ff', fontsize=10, fontweight='bold',
            alpha=0.9, zorder=51)
def draw_sphere_grid_3d(ax, radius, n_lines=16):
    """Draw a wireframe sphere in 3D PCA space — THICK.

    Args:
        ax: a 3-D matplotlib axes (plots are given x, y and z arrays).
        radius: sphere radius in the plotted (projected) coordinates.
        n_lines: number of longitude meridians; latitude rings use the
            n_lines - 1 interior polar angles (the poles are skipped).

    3-D axes ignore zorder, so artists render in the order they are added:
    call this after the scatter plots it should appear on top of.
    """
    print(f" >>> DRAWING 3D WIREFRAME: radius={radius:.4f}, lw=1.2+3")
    theta = np.linspace(0, 2 * np.pi, 80)  # azimuth samples for rings
    phi = np.linspace(0, np.pi, 40)        # polar samples for meridians
    # Latitude rings: one circle per interior polar angle p, at height radius*cos(p)
    for p in np.linspace(0, np.pi, n_lines + 1)[1:-1]:
        r = radius * np.sin(p)
        z = radius * np.cos(p)
        ax.plot(r * np.cos(theta), r * np.sin(theta),
                z * np.ones_like(theta),
                color='white', alpha=0.4, lw=1.2)
    # Longitude meridians: pole-to-pole half circles at evenly spaced azimuths t
    for t in np.linspace(0, 2 * np.pi, n_lines, endpoint=False):
        x = radius * np.sin(phi) * np.cos(t)
        y = radius * np.sin(phi) * np.sin(t)
        z = radius * np.cos(phi)
        ax.plot(x, y, z, color='white', alpha=0.4, lw=1.2)
    # Equator — bright cyan, extra thick; drawn last so it renders on top of the wireframe
    ax.plot(radius * np.cos(theta), radius * np.sin(theta),
            np.zeros_like(theta), color='#00e5ff', alpha=0.9, lw=3)
# PCA basis
if HAS_PLT:
    # Principal axes come from the first 5000 normalized embeddings, centered for
    # the SVD. The projections themselves are NOT re-centered, so the sphere's
    # center stays at the origin of every PCA plot.
    head = embs_n[:5000]
    embs_c = head - head.mean(0, keepdim=True)
    Vt = torch.linalg.svd(embs_c, full_matrices=False).Vh
    top2, top3 = Vt[:2].T, Vt[:3].T
    proj_2d, anch_2d = (embs_n @ top2).numpy(), (anchors_n @ top2).numpy()
    proj_3d, anch_3d = (embs_n @ top3).numpy(), (anchors_n @ top3).numpy()
    proj_labels = labels.numpy()
    # Grid radius = 95th percentile of projected embedding radii
    # (2D over the first 5000 points, 3D over the first 3000).
    emb_radii_2d = np.sqrt(proj_2d[:5000, 0]**2 + proj_2d[:5000, 1]**2)
    emb_radii_3d = np.sqrt((proj_3d[:3000]**2).sum(axis=1))
    sphere_r_2d = np.percentile(emb_radii_2d, 95)
    sphere_r_3d = np.percentile(emb_radii_3d, 95)
    # Fallback for a tiny projected cloud: 90% of the largest |coordinate|
    # among the embeddings and anchors.
    data_range_2d = max(np.abs(proj_2d[:5000]).max(), np.abs(anch_2d).max())
    data_range_3d = max(np.abs(proj_3d[:3000]).max(), np.abs(anch_3d).max())
    sphere_r_2d = data_range_2d * 0.9 if sphere_r_2d < 0.01 else sphere_r_2d
    sphere_r_3d = data_range_3d * 0.9 if sphere_r_3d < 0.01 else sphere_r_3d
    print(f" Sphere radius (2D): {sphere_r_2d:.4f} (3D): {sphere_r_3d:.4f}")
    print(f" Data range (2D): {data_range_2d:.4f} (3D): {data_range_3d:.4f}")
# ── [1] PCA embedding space ──
if HAS_PLT:
    print(" [1/8] PCA projection...")
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    pts_2d = proj_2d[:5000]
    lbls_2d = proj_labels[:5000]
    for c in range(N_CLASSES):
        mask = lbls_2d == c
        if mask.sum() == 0:
            continue
        # Per-class legend entries only while the legend stays readable (<= 20 classes).
        lbl = CLASS_NAMES[c] if N_CLASSES <= 20 else None
        ax.scatter(pts_2d[mask, 0], pts_2d[mask, 1],
                   c=CLASS_COLORS[c], s=4, alpha=0.5, label=lbl, zorder=2)
    ax.scatter(anch_2d[:, 0], anch_2d[:, 1],
               c='#FFD700', s=60, marker='*', edgecolors='white', linewidths=0.3,
               zorder=5, label='anchors')
    # Grid drawn LAST — on top of everything
    draw_sphere_grid_2d(ax, sphere_r_2d)
    if N_CLASSES <= 20:
        ax.legend(fontsize=7, markerscale=2, loc='upper right', ncol=2)
    ax.set_title(f'GeoLIP Core — PCA Embedding Space ({ds_name})\n'
                 f'val={val_acc:.1f}% | {total_params:,} params | '
                 f'CV={v_cv:.4f} | {n_active}/{n_anchors} anchors', fontsize=11)
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.15, color='#555577')
    plt.tight_layout()
    save_fig('01_pca_embedding_space.png')
    plt.close()
# ── [2] Triangulation connections ──
if HAS_PLT:
    print(" [2/8] Triangulation connections...")
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    subset = min(500, len(embs))
    # One faint line from each of the first `subset` images to its nearest anchor
    # (tolist() once instead of a per-element .item() call).
    for i, (a_idx, lbl_i) in enumerate(zip(nearest[:subset].tolist(),
                                           labels[:subset].tolist())):
        ax.plot([proj_2d[i, 0], anch_2d[a_idx, 0]],
                [proj_2d[i, 1], anch_2d[a_idx, 1]],
                c=CLASS_COLORS[lbl_i], alpha=0.1, linewidth=0.5)
    pts_2d = proj_2d[:5000]
    lbls_2d = proj_labels[:5000]
    for c in range(N_CLASSES):
        mask = lbls_2d == c
        if mask.sum() == 0:
            continue
        ax.scatter(pts_2d[mask, 0], pts_2d[mask, 1],
                   c=CLASS_COLORS[c], s=5, alpha=0.4, zorder=2)
    ax.scatter(anch_2d[:, 0], anch_2d[:, 1],
               c='#FFD700', s=80, marker='*', edgecolors='white', linewidths=0.3, zorder=5)
    # Tag each used anchor with the most common true class among its samples
    # (skipped for large constellations to keep the plot legible).
    if n_anchors <= 128:
        for a in range(n_anchors):
            a_mask = nearest == a
            if a_mask.sum() > 0:
                dom_class = labels[a_mask].mode().values.item()
                ax.annotate(str(dom_class), (anch_2d[a, 0], anch_2d[a, 1]),
                            fontsize=4, ha='center', va='center',
                            color='white', fontweight='bold',
                            bbox=dict(boxstyle='round,pad=0.1',
                                      fc=CLASS_COLORS[dom_class],
                                      ec='#FFD700', linewidth=0.5,
                                      alpha=0.85))
    # Grid drawn LAST
    draw_sphere_grid_2d(ax, sphere_r_2d)
    ax.set_title(f'Triangulation: Image → Nearest Anchor ({ds_name})', fontsize=11)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.15, color='#555577')
    plt.tight_layout()
    save_fig('02_triangulation_connections.png')
    plt.close()
# ── [3] 3D sphere ──
if HAS_PLT:
    print(" [3/8] 3D sphere projection...")
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    n_3d = min(3000, len(embs))
    pts_3d = proj_3d[:n_3d]
    lbls_3d = proj_labels[:n_3d]
    # NOTE: only classes 0..19 are scattered (presumably to keep CIFAR-100 legible),
    # while the title below still reports N_CLASSES.
    for c in range(min(N_CLASSES, 20)):
        mask = lbls_3d == c
        if mask.sum() == 0:
            continue
        ax.scatter(pts_3d[mask, 0], pts_3d[mask, 1], pts_3d[mask, 2],
                   c=CLASS_COLORS[c], s=5, alpha=0.4,
                   label=CLASS_NAMES[c] if N_CLASSES <= 20 else None)
    ax.scatter(anch_3d[:, 0], anch_3d[:, 1], anch_3d[:, 2],
               c='#FFD700', s=40, marker='*', edgecolors='white', linewidths=0.3, zorder=5)
    # Wireframe drawn AFTER data — 3D has no zorder, draw order is render order
    draw_sphere_grid_3d(ax, sphere_r_3d)
    if N_CLASSES <= 20:
        ax.legend(fontsize=6, markerscale=2, loc='upper left', ncol=2)
    ax.set_title(f'3D PCA — Constellation on the Sphere\n'
                 f'{n_anchors} anchors, {N_CLASSES} classes', fontsize=11)
    try:
        ax.set_box_aspect([1, 1, 1])
    except AttributeError:
        pass  # older matplotlib without Axes3D.set_box_aspect
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    plt.tight_layout()
    save_fig('03_3d_sphere.png')
    plt.close()
# ── [4] Anchor-Class heatmap ──
if HAS_PLT:
    print(" [4/8] Anchor-class assignment matrix...")
    # assign_mat[c, a] = number of class-c samples whose nearest anchor is a. One
    # bincount over (class, anchor) pair ids replaces the C x A Python loop; labels
    # outside [0, N_CLASSES) are ignored, as the loop over range(N_CLASSES) did.
    in_range = labels < N_CLASSES
    pair_ids = labels[in_range] * n_anchors + nearest[in_range]
    assign_mat = torch.bincount(pair_ids, minlength=N_CLASSES * n_anchors) \
        .reshape(N_CLASSES, n_anchors).float()
    # Row-normalize so each class row is a distribution over anchors.
    assign_norm = assign_mat / (assign_mat.sum(dim=1, keepdim=True) + 1e-8)
    # Order anchor columns by the class that uses them most, giving a block-diagonal look.
    peak_class = assign_norm.argmax(dim=0)
    sort_order = peak_class.argsort()
    assign_sorted = assign_norm[:, sort_order]
    h = max(6, N_CLASSES * 0.12)
    fig, ax = plt.subplots(1, 1, figsize=(16, h))
    im = ax.imshow(assign_sorted.numpy(), aspect='auto', cmap='inferno')
    if N_CLASSES <= 30:
        ax.set_yticks(range(N_CLASSES))
        ax.set_yticklabels(CLASS_NAMES, fontsize=max(4, 9 - N_CLASSES // 15))
    ax.set_xlabel('Anchor index (sorted by peak class)')
    ax.set_title(f'Class → Anchor Assignment ({ds_name})', fontsize=11)
    plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout()
    save_fig('04_anchor_class_heatmap.png')
    plt.close()
# ── [5] Triangulation profiles ──
if HAS_PLT:
    print(" [5/8] Class triangulation profiles...")
    if N_CLASSES <= 10:
        show_classes = list(range(N_CLASSES))
    else:
        sorted_by_acc = sorted(range(N_CLASSES), key=lambda c: class_accs[c])
        show_classes = sorted_by_acc[:5] + sorted_by_acc[-5:]
    nrows, ncols = 2, 5
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 8))
    x = np.arange(n_anchors)
    for idx, c in enumerate(show_classes):
        ax = axes[idx // ncols][idx % ncols]
        # Per-class mean ± std of the per-anchor triangulation values.
        c_tris = tris[labels == c]
        if len(c_tris) == 0:
            ax.set_visible(False)  # no samples: hide instead of leaving an empty panel
            continue
        mean_tri = c_tris.mean(0).numpy()
        std_tri = c_tris.std(0).numpy()
        color = CLASS_COLORS[c]
        ax.fill_between(x, mean_tri - std_tri, mean_tri + std_tri,
                        alpha=0.3, color=color)
        ax.plot(x, mean_tri, color=color, linewidth=1.5)
        ax.set_title(f'{CLASS_NAMES[c]} ({class_accs[c]:.0f}%)',
                     fontsize=9, fontweight='bold', color=color)
        ax.set_ylim(0, max(1.6, mean_tri.max() * 1.2))
        ax.tick_params(labelsize=5)
    # Fewer than 10 classes leaves spare panels in the 2x5 grid: hide them.
    for ax in axes.flat[len(show_classes):]:
        ax.set_visible(False)
    tag = "all classes" if N_CLASSES <= 10 else "5 worst + 5 best"
    plt.suptitle(f'Triangulation Fingerprints ({tag})', fontsize=12)
    plt.tight_layout()
    save_fig('05_triangulation_profiles.png')
    plt.close()
# ── [6] Anchor utilization ──
if HAS_PLT:
    print(" [6/8] Anchor utilization...")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    sorted_counts, _ = counts.sort(descending=True)
    # Used anchors in cyan, unused (zero-count) anchors in red.
    ax1.bar(range(n_anchors), sorted_counts.numpy(),
            color=['#00BCD4' if c > 0 else '#FF5252' for c in sorted_counts.tolist()],
            width=1.0)
    ax1.set_xlabel('Anchor (sorted)')
    ax1.set_ylabel('Assigned samples')
    ax1.set_title(f'Anchor Utilization ({n_active}/{n_anchors} active)')
    # Dashed line: the per-anchor count under a perfectly uniform split.
    ax1.axhline(y=len(labels) / n_anchors, color='#888899', linestyle='--', alpha=0.5)
    # Per-class anchor entropy (nats): how spread each class's samples are over
    # their nearest anchors. One bincount builds the (class, anchor) count matrix.
    in_range = labels < N_CLASSES
    class_anchor = torch.bincount(labels[in_range] * n_anchors + nearest[in_range],
                                  minlength=N_CLASSES * n_anchors) \
        .reshape(N_CLASSES, n_anchors).float()
    dist = class_anchor / (class_anchor.sum(dim=1, keepdim=True) + 1e-8)
    entropies = (-(dist * (dist + 1e-10).log()).sum(dim=1)).tolist()
    if N_CLASSES <= 20:
        ax2.barh(range(N_CLASSES), entropies,
                 color=[CLASS_COLORS[c] for c in range(N_CLASSES)])
        ax2.set_yticks(range(N_CLASSES))
        ax2.set_yticklabels(CLASS_NAMES, fontsize=8)
        ax2.set_xlabel('Anchor assignment entropy')
    else:
        ax2.hist(entropies, bins=30, color='#00BCD4', edgecolor='#333355')
        ax2.set_xlabel('Anchor assignment entropy')
        ax2.set_ylabel('Number of classes')
    # Gini coefficient of anchor utilization over ascending counts:
    #   G = (n + 1 - 2 * sum_i(cum_i) / total) / n
    # 0 for a perfectly even split, (n-1)/n when one anchor takes every sample.
    # The "+1" term keeps a perfectly even split at exactly 0 rather than -1/n.
    c_sorted = counts.float().sort().values
    n_c = c_sorted.numel()
    total = c_sorted.sum().item()
    if total > 0:
        gini = (n_c + 1 - 2 * c_sorted.cumsum(0).sum().item() / total) / n_c
    else:
        gini = 0.0
    ax2.set_title(f'Anchor Spread (Gini={gini:.3f})')
    plt.tight_layout()
    save_fig('06_anchor_utilization.png')
    plt.close()
# ── [7] Patchwork compartment responses ──
if HAS_PLT:
    print(" [7/8] Patchwork compartment responses...")
    n_comp = cfg.get('n_comp', 8)
    # Compartment id per column of `tris` (assumed: one entry per anchor — TODO confirm).
    asgn = model.patchwork.asgn.cpu()
    # Same class selection as the fingerprint plot.
    if N_CLASSES <= 10:
        show_c = list(range(N_CLASSES))
    else:
        show_c = show_classes
    if n_comp > 0:
        ncols_pw = min(4, n_comp)
        nrows_pw = math.ceil(n_comp / ncols_pw)
        # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
        fig, axes = plt.subplots(nrows_pw, ncols_pw,
                                 figsize=(4 * ncols_pw, 3 * nrows_pw), squeeze=False)
        axes_flat = list(axes.flat)
        for k in range(n_comp):
            ax = axes_flat[k]
            comp_tris = tris[:, asgn == k]
            if comp_tris.shape[1] == 0:
                ax.set_visible(False)  # compartment owns no anchors: nothing to show
                continue
            # Rows: mean triangulation of each shown class over this compartment's anchors.
            class_means = []
            class_labels_show = []
            for c in show_c:
                rows = comp_tris[labels == c]
                if len(rows) > 0:
                    class_means.append(rows.mean(0).numpy())
                    class_labels_show.append(CLASS_NAMES[c])
            if not class_means:
                continue
            ax.imshow(np.stack(class_means), aspect='auto', cmap='plasma')
            ax.set_yticks(range(len(class_labels_show)))
            ax.set_yticklabels(class_labels_show, fontsize=6)
            ax.set_title(f'Comp {k}', fontsize=9)
        # Hide grid cells beyond the last compartment.
        for ax in axes_flat[n_comp:]:
            ax.set_visible(False)
        plt.suptitle('Patchwork Compartment Responses by Class', fontsize=12)
        plt.tight_layout()
        save_fig('07_patchwork_compartments.png')
        plt.close()
# ── [8] Confusion matrix ──
if HAS_PLT:
    print(" [8/8] Confusion matrix...")
    # conf_mat[t, p] = number of samples with true class t predicted as p, built
    # with one bincount over t * N + p instead of a Python loop over every sample.
    in_range = (labels < N_CLASSES) & (preds < N_CLASSES)
    conf_mat = torch.bincount(labels[in_range] * N_CLASSES + preds[in_range],
                              minlength=N_CLASSES * N_CLASSES) \
        .reshape(N_CLASSES, N_CLASSES)
    # Row-normalized percentages: where each true class's samples ended up.
    conf_pct = conf_mat.float() / (conf_mat.sum(dim=1, keepdim=True) + 1e-8) * 100
    if N_CLASSES <= 20:
        fig, ax = plt.subplots(1, 1, figsize=(8, 7))
        im = ax.imshow(conf_pct.numpy(), cmap='magma', vmin=0, vmax=100)
        for i in range(N_CLASSES):
            for j in range(N_CLASSES):
                v = conf_pct[i, j].item()
                # Dark text on bright (high-%) cells, light text elsewhere.
                ax.text(j, i, f'{v:.0f}', ha='center', va='center',
                        fontsize=max(4, 8 - N_CLASSES // 5),
                        color='black' if v > 60 else '#e0e0e0')
        ax.set_xticks(range(N_CLASSES))
        ax.set_yticks(range(N_CLASSES))
        ax.set_xticklabels(CLASS_NAMES, rotation=45, ha='right', fontsize=7)
        ax.set_yticklabels(CLASS_NAMES, fontsize=7)
    else:
        fig, ax = plt.subplots(1, 1, figsize=(14, 12))
        im = ax.imshow(conf_pct.numpy(), cmap='magma', vmin=0, vmax=100)
    ax.set_xlabel('Predicted class')
    ax.set_ylabel('True class')
    ax.set_title(f'Confusion Matrix — {val_acc:.1f}% ({ds_name})', fontsize=11)
    plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout()
    save_fig('08_confusion_matrix.png')
    plt.close()
    print(f"\n ✓ All 8 visualizations saved to {OUT_DIR}/")
# ══════════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*70}")
print("SUMMARY")
print(f"{'='*70}")
print(f" Dataset: {ds_name} ({N_CLASSES} classes)")
print(f" Params: {total_params:,}")
print(f" Val accuracy: {val_acc:.1f}%")
print(f" Eff dim: {eff_dim:.1f}/{embs.shape[1]}")
print(f" Same-class cos: {same_cos:.4f}")
print(f" Diff-class cos: {diff_cos:.4f}")
print(f" Gap: {gap:.4f}")
print(f" CV: {v_cv:.4f}")
print(f" Anchors active: {n_active}/{n_anchors}")
worst_i = min(range(N_CLASSES), key=lambda c: class_accs[c])
best_i = max(range(N_CLASSES), key=lambda c: class_accs[c])
print(f" Worst class: {CLASS_NAMES[worst_i]} ({class_accs[worst_i]:.1f}%)")
print(f" Best class: {CLASS_NAMES[best_i]} ({class_accs[best_i]:.1f}%)")
# Heuristic thresholds for flagging anchor collapse, embedding collapse and
# weak class separation. (Named to avoid shadowing the stdlib `warnings` module.)
health_warnings = []
if n_active < n_anchors * 0.5:
    health_warnings.append(f"Anchor collapse: {n_active}/{n_anchors}")
if eff_dim < 5:
    health_warnings.append(f"Embedding collapse: eff_dim={eff_dim:.1f}")
if gap < 0.02:
    health_warnings.append(f"Low class separation: gap={gap:.4f}")
if health_warnings:
    print(f"\n ⚠ WARNINGS: {', '.join(health_warnings)}")
else:
    print("\n ✓ All diagnostics healthy")
print(f"\n{'='*70}")
print("ANALYSIS COMPLETE")
print(f"{'='*70}")
# ══════════════════════════════════════════════════════════════════
# PUSH IMAGES TO HUGGINGFACE
# ══════════════════════════════════════════════════════════════════
if HF_PUSH:
    # Imported lazily so huggingface_hub is only needed when uploading.
    # Writing to the repo requires Hub credentials (e.g. `huggingface-cli login` or HF_TOKEN).
    from huggingface_hub import upload_folder
    print(f"\n Uploading {OUT_DIR}/ → {HF_REPO_ID}/analysis/ ...")
    # Uploads the contents of OUT_DIR into the repo's analysis/ folder as one commit.
    upload_folder(
        repo_id=HF_REPO_ID,
        folder_path=OUT_DIR,
        path_in_repo="analysis",
        commit_message=f"Analysis: val={val_acc:.1f}% CV={v_cv:.4f} {n_active}/{n_anchors} anchors",
    )
    print(f" ✓ Done: https://huggingface.co/{HF_REPO_ID}/tree/main/analysis")