| |
| """ |
| GeoLIP Core — Full Analysis + Sphere Visualizations |
| ===================================================== |
| Auto-detects CIFAR-10 vs CIFAR-100 from checkpoint config. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| import os |
| from collections import defaultdict |
| from torchvision import datasets, transforms |
|
|
# --- Run configuration -------------------------------------------------------
# Device used for every forward/backward pass in this script.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CKPT = "checkpoints/geolip_core_best.pt"  # checkpoint to analyse
OUT_DIR = "analysis_out"                   # directory the figures are saved to
BATCH = 256                                # validation batch size

# Hugging Face Hub repository that receives OUT_DIR when HF_PUSH is true.
HF_REPO_ID = "AbstractPhil/geolip-constellation-core"
HF_PUSH = True

# Per-channel statistics passed to transforms.Normalize for validation images.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

os.makedirs(OUT_DIR, exist_ok=True)

# Banner. The dash was mis-decoded in the source text and is restored here.
print("=" * 70)
print("GEOLIP CORE — ANALYSIS + SPHERE VISUALIZATIONS")
print(f" Checkpoint: {CKPT}")
print(f" Output: {OUT_DIR}/")
print("=" * 70)
|
|
| |
| |
| |
|
|
# Load the checkpoint on CPU; the model is moved to DEVICE after construction.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only run this script on checkpoints from a trusted source.
ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
cfg = ckpt["config"]
N_CLASSES = cfg.get('num_classes', 10)
print(f" Epoch: {ckpt['epoch']} Val acc: {ckpt['val_acc']:.1f}%")
print(f" Config: output_dim={cfg.get('output_dim')}, "
      f"n_anchors={cfg.get('n_anchors')}, "
      f"n_comp={cfg.get('n_comp')}, d_comp={cfg.get('d_comp')}, "
      f"num_classes={N_CLASSES}")

# Dataset auto-detection: up to 10 classes means CIFAR-10, otherwise CIFAR-100.
if N_CLASSES <= 10:
    ds_cls, ds_name = datasets.CIFAR10, "CIFAR-10"
else:
    ds_cls, ds_name = datasets.CIFAR100, "CIFAR-100"

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_ds = ds_cls(root='./data', train=False, download=True, transform=val_transform)

# Class names come from the dataset object that is needed anyway, instead of
# constructing a second throw-away CIFAR-100 instance just to read `.classes`.
if N_CLASSES <= 10:
    CLASS_NAMES = CIFAR10_CLASSES[:N_CLASSES]
else:
    CLASS_NAMES = val_ds.classes

print(f" Dataset: {ds_name} ({N_CLASSES} classes)")

# NOTE(review): GeoLIPCore is not imported in the visible part of this file;
# it must already be in scope (e.g. defined earlier in the session).
model = GeoLIPCore(**cfg).to(DEVICE)
model.load_state_dict(ckpt["state_dict"])
model.eval()

# Pinned host memory only helps host-to-GPU copies, so request it only on CUDA.
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False, num_workers=2,
    pin_memory=(DEVICE == "cuda"))

total_params = sum(p.numel() for p in model.parameters())
|
|
| |
| |
| |
|
|
print("\n Collecting embeddings...")

# One list of per-batch CPU tensors for each quantity gathered from the model.
_batches = defaultdict(list)

with torch.no_grad():
    for batch_imgs, batch_lbls in val_loader:
        result = model(batch_imgs.to(DEVICE))
        batch_logits = result['logits']
        _batches['emb'].append(result['embedding'].float().cpu())
        _batches['tri'].append(result['triangulation'].float().cpu())
        _batches['nearest'].append(result['nearest'].cpu())
        _batches['label'].append(batch_lbls)
        _batches['pred'].append(batch_logits.argmax(-1).cpu())
        _batches['logit'].append(batch_logits.float().cpu())

# Concatenate the per-batch pieces along the sample dimension.
embs = torch.cat(_batches['emb'])
tris = torch.cat(_batches['tri'])
nearest = torch.cat(_batches['nearest'])
labels = torch.cat(_batches['label'])
preds = torch.cat(_batches['pred'])
logits = torch.cat(_batches['logit'])

# Unit-length copies of the embeddings are used for every cosine computation.
embs_n = F.normalize(embs, dim=-1)
n_hits = (preds == labels).float()
val_acc = n_hits.mean().item() * 100
print(f" Val accuracy: {val_acc:.1f}%")
print(f" Embeddings: {embs.shape}")
|
|
| |
| |
| |
|
|
N_PUSH_STEPS = 30  # number of push iterations
PUSH_LR = 0.5      # interpolation factor toward the target on each iteration

print(f"\n Pushing anchors toward CLASS centroids ({N_PUSH_STEPS} steps, lr={PUSH_LR})...")

# Snapshot the anchors before they are moved so drift can be measured later.
anchors_before = model.constellation.anchors.detach().float().cpu().clone()
anch_n_before = F.normalize(anchors_before, dim=-1)
cos_before = (embs_n @ anch_n_before.T).max(dim=1).values.mean().item()
print(f" Before: mean nearest_cos = {cos_before:.4f}")

# Device copies of the collected embeddings/labels for the push computation.
emb_device = embs.to(DEVICE)
lbl_device = labels.to(DEVICE)

if hasattr(model, 'push_anchors_to_centroids'):
    # Preferred path: the model provides its own push routine.
    for step in range(N_PUSH_STEPS):
        moved = model.push_anchors_to_centroids(emb_device, lbl_device, lr=PUSH_LR)
        if (step + 1) % 10 == 0:
            an_tmp = F.normalize(model.constellation.anchors.detach().float().cpu(), dim=-1)
            c_tmp = (embs_n @ an_tmp.T).max(dim=1).values.mean().item()
            print(f" Step {step+1:3d}: nearest_cos = {c_tmp:.4f}, moved = {moved}")
else:
    # Fallback: manual push, writing directly into the anchor parameter data.
    with torch.no_grad():
        anchors_param = model.constellation.anchors.data
        emb_dev = F.normalize(emb_device, dim=-1)

        # Unit-norm centroid of the normalized embeddings of every class.
        classes = lbl_device.unique()
        n_cls = classes.shape[0]
        centroids = torch.cat(
            [F.normalize(emb_dev[lbl_device == c].mean(0, keepdim=True), dim=-1)
             for c in classes],
            dim=0)

        n_a = anchors_param.shape[0]
        anchors_per_class = n_a // n_cls
        per_class_cap = anchors_per_class + 1  # max anchors one class may take

        for step in range(N_PUSH_STEPS):
            an = F.normalize(anchors_param, dim=-1)
            cos_ac = an @ centroids.T

            # Greedy assignment: walk (anchor, class) pairs from the highest
            # cosine down, giving each anchor its best class that still has
            # capacity. The sorted indices are pulled to the host once, so the
            # loop does not force a device sync per element.
            _, flat_idx = cos_ac.flatten().sort(descending=True)
            assigned = [-1] * n_a
            cls_count = [0] * n_cls
            n_assigned = 0
            for idx in flat_idx.tolist():
                a, c_idx = divmod(idx, n_cls)
                if assigned[a] >= 0:
                    continue
                if cls_count[c_idx] >= per_class_cap:
                    continue
                assigned[a] = c_idx
                cls_count[c_idx] += 1
                n_assigned += 1
                if n_assigned == n_a:
                    break

            # Safety net: any anchor still unassigned takes its nearest centroid.
            leftover = [a for a, c_idx in enumerate(assigned) if c_idx < 0]
            if leftover:
                left_idx = torch.tensor(leftover, dtype=torch.long, device=an.device)
                best = (an[left_idx] @ centroids.T).argmax(dim=1).tolist()
                for a, c_idx in zip(leftover, best):
                    assigned[a] = c_idx

            # Move every anchor toward its class centroid. Later anchors of the
            # same class get a small tangential perturbation of the target so
            # they do not all converge on the same point.
            seen_per_class = [0] * n_cls
            for a, c_idx in enumerate(assigned):
                target = centroids[c_idx]
                rank = seen_per_class[c_idx]  # earlier anchors sharing this class
                seen_per_class[c_idx] += 1
                if rank > 0:
                    noise = torch.randn_like(target) * 0.05
                    # Remove the component along the target (keep it tangential).
                    noise = noise - (noise * target).sum() * target
                    target = F.normalize((target + noise).unsqueeze(0), dim=-1).squeeze(0)
                anchors_param[a] = F.normalize(
                    (an[a] + PUSH_LR * (target - an[a])).unsqueeze(0), dim=-1).squeeze(0)

            if (step + 1) % 10 == 0:
                an_tmp = F.normalize(anchors_param, dim=-1)
                c_tmp = (emb_dev @ an_tmp.T).max(dim=1).values.mean().item()
                print(f" Step {step+1:3d}: nearest_cos = {c_tmp:.4f}")
|
|
| |
# Re-read the anchors after the push and recompute everything derived from them.
anchors = model.constellation.anchors.detach().float().cpu()
anchors_n = F.normalize(anchors, dim=-1)
n_anchors = anchors.shape[0]

# Cosine of every embedding against every anchor, computed once and reused for
# the summary statistic, the triangulation distances and the nearest index.
with torch.no_grad():
    new_cos = embs_n @ anchors_n.T
    tris = 1.0 - new_cos
    nearest = new_cos.argmax(dim=1)

cos_after = new_cos.max(dim=1).values.mean().item()
# Mean distance between each normalized anchor before and after the push.
drift = (F.normalize(anchors_before, dim=-1) - anchors_n).norm(dim=-1).mean().item()
print(f" After: mean nearest_cos = {cos_after:.4f} (Δ={cos_after - cos_before:+.4f})")
print(f" Anchor drift: {drift:.4f}")

print(f" Anchors: {anchors.shape}")
|
|
| |
| |
| |
|
|
print(f"\n{'='*70}")
print("AUDIT 1: NUMERIC HEALTH")
print(f"{'='*70}")

# Scan every parameter for NaN/inf entries and for collapsed (near-constant)
# values; `issues` collects the names of the offending parameters.
issues = []
for name, param in model.named_parameters():
    p = param.detach().float()
    n_nan = torch.isnan(p).sum().item()
    n_inf = torch.isinf(p).sum().item()
    # std() is undefined for a single element, so treat it as 0 there.
    p_std = p.std().item() if p.numel() > 1 else 0
    flags = []
    if n_nan > 0:
        flags.append(f"NaN={n_nan}")
    if n_inf > 0:
        flags.append(f"inf={n_inf}")
    if p_std < 1e-8 and p.numel() > 1:
        flags.append(f"COLLAPSED(std={p_std:.2e})")
    if flags:
        print(f" ✗ {name:<50} {' '.join(flags)}")
        issues.append(name)

if not issues:
    print(f" ✓ All {total_params:,} parameters clean")
|
|
| |
| |
| |
|
|
_rule = '=' * 70
print(f"\n{_rule}")
print("AUDIT 2: PER-CLASS ACCURACY")
print(f"{_rule}")

# Accuracy (in percent) for every class index; 0 when a class has no samples.
class_accs = []
for cls_idx in range(N_CLASSES):
    hits = preds[labels == cls_idx] == cls_idx
    if hits.numel() > 0:
        class_accs.append(hits.float().mean().item() * 100)
    else:
        class_accs.append(0)

if N_CLASSES > 10:
    # Too many classes to list: show the ten weakest and the ten strongest.
    ranked = sorted(range(N_CLASSES), key=lambda k: class_accs[k])
    for heading, group in ((" Bottom 10:", ranked[:10]), (" Top 10:", ranked[-10:])):
        print(heading)
        for cls_idx in group:
            print(f" {CLASS_NAMES[cls_idx]:<20}: {class_accs[cls_idx]:5.1f}%")
    # NOTE(review): the source lost its indentation; this summary line is
    # assumed to belong to the many-class branch — confirm against the original.
    print(f" Mean: {np.mean(class_accs):.1f}% "
          f"Median: {np.median(class_accs):.1f}% "
          f"Std: {np.std(class_accs):.1f}%")
else:
    for cls_idx, cls_acc in enumerate(class_accs):
        print(f" {CLASS_NAMES[cls_idx]:<12}: {cls_acc:5.1f}%")
|
|
| |
| |
| |
|
|
_rule = '=' * 70
print(f"\n{_rule}")
print("AUDIT 3: EMBEDDING SPACE")
print(f"{_rule}")

# Pairwise cosine similarities on a prefix of at most 2000 embeddings.
n_sample = min(2000, len(embs))
_head = embs_n[:n_sample]
sim = _head @ _head.T
sim_mask = ~torch.eye(n_sample, dtype=torch.bool)  # True everywhere off-diagonal
labels_s = labels[:n_sample]
same_class = labels_s[None, :] == labels_s[:, None]
same_not_self = same_class & sim_mask
diff_class = (~same_class) & sim_mask

# Mean similarity over all off-diagonal pairs, same-class pairs and
# different-class pairs; `gap` is the separation between the latter two.
self_sim = sim[sim_mask].mean().item()
if same_not_self.any():
    same_cos = sim[same_not_self].mean().item()
else:
    same_cos = 0
if diff_class.any():
    diff_cos = sim[diff_class].mean().item()
else:
    diff_cos = 0
gap = same_cos - diff_cos

# Effective dimension: inverse participation ratio of the normalized
# singular values of the first 512 embeddings.
S = torch.linalg.svd(embs_n[:512].float(), full_matrices=False).S
p = S / S.sum()
eff_dim = p.pow(2).sum().reciprocal().item()

print(f" Self-similarity: {self_sim:.4f}")
print(f" Same-class cos: {same_cos:.4f}")
print(f" Diff-class cos: {diff_cos:.4f}")
print(f" Gap: {gap:.4f}")
print(f" Effective dim: {eff_dim:.1f}/{embs.shape[1]}")
|
|
| |
| |
| |
|
|
print(f"\n{'='*70}")
print("AUDIT 4: CONSTELLATION HEALTH")
print(f"{'='*70}")

# Off-diagonal anchor-to-anchor cosines show how spread out the anchors are.
anch_sim = anchors_n @ anchors_n.T
anch_mask = ~torch.eye(n_anchors, dtype=torch.bool)
anch_off = anch_sim[anch_mask]
n_active = nearest.unique().numel()  # anchors that are nearest to >= 1 sample

# Samples per anchor in one pass; minlength keeps zero-count anchors.
counts = torch.bincount(nearest, minlength=n_anchors)

print(f" Anchors: {n_anchors} × {anchors.shape[1]}")
print(f" Pairwise cos: mean={anch_off.mean():.4f} max={anch_off.max():.4f}")
print(f" Active: {n_active}/{n_anchors}")
print(f" Utilization: min={counts.min().item()} max={counts.max().item()} "
      f"mean={counts.float().mean():.1f} std={counts.float().std():.1f}")
|
|
| |
| |
| |
|
|
print(f"\n{'='*70}")
print("AUDIT 5: PENTACHORON CV")
print(f"{'='*70}")

# Coefficient of variation of the volumes of random 5-point simplices
# (pentachora) drawn from the first 2000 normalized embeddings.
sample = embs_n[:2000].to(DEVICE)
n_pool = len(sample)
vols = []
if n_pool >= 5:  # a pentachoron needs five points
    with torch.no_grad():
        # Cayley-Menger matrix: the border of ones is constant, only the
        # 5x5 squared-distance block changes per draw.
        cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        for _ in range(500):
            idx = torch.randperm(n_pool, device=DEVICE)[:5]
            pts = sample[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            # Squared pairwise distances from the Gram matrix; relu clamps
            # tiny negative values caused by round-off.
            d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
            cm[:, 1:, 1:] = d2
            # For a 4-simplex: V^2 = -det(CM) / (2^4 * (4!)^2) = -det(CM) / 9216.
            v2 = -torch.linalg.det(cm) / 9216
            if v2[0].item() > 1e-20:  # skip degenerate simplices
                vols.append(v2[0].sqrt().cpu())

if len(vols) > 10:
    vt = torch.stack(vols)
    v_cv = (vt.std() / (vt.mean() + 1e-8)).item()
    band = "✓ IN BAND" if 0.18 <= v_cv <= 0.25 else "✗ outside"
    print(f" CV: {v_cv:.4f} ({band})")
    print(f" Vol mean: {vt.mean():.6f} std: {vt.std():.6f}")
else:
    v_cv = 0
    print(f" ⚠ Not enough valid pentachora ({len(vols)})")
|
|
| |
| |
| |
|
|
_rule = '=' * 70
print(f"\n{_rule}")
print("AUDIT 6: CONFIDENCE CALIBRATION")
print(f"{_rule}")

# Confidence = highest softmax probability of each sample.
probs = logits.softmax(-1)
conf = probs.max(dim=1).values
correct_mask = preds == labels

right_conf = conf[correct_mask]
print(f" Correct: mean_conf={right_conf.mean():.4f} "
      f"std={right_conf.std():.4f}")

# Statistics for misclassified samples are only reported when there are any.
wrong_mask = ~correct_mask
if wrong_mask.any():
    wrong_conf = conf[wrong_mask]
    n_wrong = wrong_conf.numel()
    overconf = (wrong_conf > 0.9).sum().item()
    print(f" Wrong: mean_conf={wrong_conf.mean():.4f} "
          f"std={wrong_conf.std():.4f}")
    print(f" Overconfident wrong (>0.9): {overconf}/{n_wrong} "
          f"({100*overconf/max(n_wrong,1):.1f}%)")
|
|
| |
| |
| |
|
|
print(f"\n{'='*70}")
print("AUDIT 7: GRADIENT FLOW")
print(f"{'='*70}")

# One backward pass on 16 validation images to see which modules get gradient.
# NOTE(review): the forward pass runs in train mode; if the model contains
# layers with running statistics they may be updated by this batch — confirm.
model.train()
model.zero_grad()
imgs_g, lbls_g = next(iter(val_loader))
imgs_g = imgs_g[:16].to(DEVICE)
lbls_g = lbls_g[:16].to(DEVICE)

# Autocast must target the device the model actually runs on; the original
# hard-coded "cuda" even when DEVICE is "cpu".
with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
    out = model(imgs_g)
    loss = F.cross_entropy(out['logits'], lbls_g) + 0.1 * out['embedding'].mean()
loss.backward()

# Gradient norms grouped by the module keyword found in the parameter name.
grad_by_mod = defaultdict(list)
missing_grad = []  # trainable parameters that received no gradient
for name, param in model.named_parameters():
    if param.grad is None:
        if param.requires_grad:
            missing_grad.append(name)
        continue
    gn = param.grad.detach().float().norm().item()
    if "encoder" in name:
        mod = "encoder"
    elif "constellation" in name:
        mod = "constellation"
    elif "patchwork" in name:
        mod = "patchwork"
    elif "classifier" in name:
        mod = "classifier"
    else:
        mod = "other"
    grad_by_mod[mod].append(gn)

for mod in sorted(grad_by_mod):
    norms = grad_by_mod[mod]
    print(f" {mod:<15}: mean={np.mean(norms):.6f} max={np.max(norms):.6f} "
          f"({len(norms)} params)")

# Only claim full coverage when no trainable parameter was skipped above.
if missing_grad:
    print(f" ⚠ {len(missing_grad)} trainable parameter(s) received no gradient: "
          f"{', '.join(missing_grad[:5])}")
else:
    print(f" ✓ All parameters receive gradient")
model.eval()
|
|
|
|
| |
| |
| |
|
|
# Optional plotting dependency: HAS_PLT records whether matplotlib imported.
# The 'Agg' backend is selected before pyplot is imported.
try:
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
except ImportError:
    HAS_PLT = False
    print("\n ⚠ matplotlib not available, skipping visualizations")
else:
    HAS_PLT = True
|
|
if HAS_PLT:
    # One hex colour per class: a fixed palette for up to ten classes,
    # otherwise generated colours.
    if N_CLASSES > 10:
        import colorsys

        _HUE_STEP = 0.618033988749895  # hue increment between classes
        CLASS_COLORS = []
        for k in range(N_CLASSES):
            rgb = colorsys.hsv_to_rgb(
                (k * _HUE_STEP) % 1.0,
                0.75 + 0.25 * (k % 3) / 2,        # saturation cycles over 3 levels
                0.85 + 0.15 * ((k + 1) % 2))      # value alternates between 2 levels
            CLASS_COLORS.append('#' + ''.join(f'{int(ch*255):02x}' for ch in rgb))
    else:
        CLASS_COLORS = [
            '#e6194b', '#3cb44b', '#4363d8', '#f58231', '#911eb4',
            '#42d4f4', '#f032e6', '#bfef45', '#469990', '#dcbeff']

    # Dark theme shared by every figure.
    plt.style.use('dark_background')
    _theme = {
        'figure.facecolor': '#1a1a2e',
        'axes.facecolor': '#16213e',
        'axes.edgecolor': '#444466',
        'axes.labelcolor': '#e0e0e0',
        'text.color': '#e0e0e0',
        'xtick.color': '#aaaacc',
        'ytick.color': '#aaaacc',
        'grid.color': '#333355',
        'legend.facecolor': '#1a1a2e',
        'legend.edgecolor': '#444466',
    }
    plt.rcParams.update(_theme)

    _rule = '=' * 70
    print(f"\n{_rule}")
    print("VISUALIZATIONS")
    print(f"{_rule}")
|
|
def save_fig(filename, dpi=200):
    """Write the current matplotlib figure to ``OUT_DIR/filename``.

    Args:
        filename: File name (with extension) inside the output directory.
        dpi: Resolution passed through to ``plt.savefig``.
    """
    target = f'{OUT_DIR}/{filename}'
    plt.savefig(target, dpi=dpi)
|
|
| |
def draw_sphere_grid_2d(ax, radius, n_meridians=24):
    """Overlay a circular reference grid of the given radius on a 2D axes.

    Draws, in order: a glowing outer circle, two dashed inner circles at 50%
    and 75% of the radius, radial tick marks around the rim, a faint
    crosshair through the origin, and a text label showing the radius.

    Args:
        ax: Axes-like object providing ``plot`` and ``text``.
        radius: Radius of the outer circle in data units.
        n_meridians: Number of evenly spaced tick marks on the rim.
    """
    print(f" >>> DRAWING 2D GRID: radius={radius:.4f}, lw=5, white+cyan")
    cyan = '#00e5ff'
    angles = np.linspace(0, 2 * np.pi, 500)
    circle_x = radius * np.cos(angles)
    circle_y = radius * np.sin(angles)

    # Outer circle: wide translucent cyan halo underneath a solid white line.
    ax.plot(circle_x, circle_y, color=cyan, alpha=0.6, lw=9, zorder=49)
    ax.plot(circle_x, circle_y, color='white', alpha=1.0, lw=5, zorder=50,
            solid_capstyle='round')

    # Dashed inner reference circles.
    for scale in (0.5, 0.75):
        ax.plot(scale * circle_x, scale * circle_y,
                color=cyan, alpha=0.5, lw=2, linestyle='--', zorder=50)

    # Short radial ticks straddling the rim (92% to 108% of the radius).
    tick_in, tick_out = radius * 0.92, radius * 1.08
    for k in range(n_meridians):
        ang = 2 * np.pi * k / n_meridians
        cos_a, sin_a = np.cos(ang), np.sin(ang)
        ax.plot([tick_in * cos_a, tick_out * cos_a],
                [tick_in * sin_a, tick_out * sin_a],
                color='white', alpha=0.8, lw=2, zorder=50)

    # Crosshair through the origin, extending slightly past the circle.
    reach = radius * 1.15
    cross_style = dict(color=cyan, alpha=0.3, lw=1.5, zorder=49)
    ax.plot([-reach, reach], [0, 0], **cross_style)
    ax.plot([0, 0], [-reach, reach], **cross_style)

    # Radius label placed near the upper-right of the circle.
    label_pos = radius * 0.72
    ax.text(label_pos, label_pos, f'r={radius:.2f}',
            color=cyan, fontsize=10, fontweight='bold',
            alpha=0.9, zorder=51)
|
|
def draw_sphere_grid_3d(ax, radius, n_lines=16):
    """Draw a wireframe sphere of the given radius on a 3D axes.

    Latitude rings are drawn first, then meridian arcs from pole to pole,
    and finally a highlighted equator.

    Args:
        ax: 3D axes-like object providing ``plot(x, y, z, ...)``.
        radius: Sphere radius in data units.
        n_lines: Number of meridians; ``n_lines - 1`` latitude rings are drawn.
    """
    print(f" >>> DRAWING 3D WIREFRAME: radius={radius:.4f}, lw=1.2+3")
    azimuth = np.linspace(0, 2 * np.pi, 80)
    polar = np.linspace(0, np.pi, 40)
    cos_az, sin_az = np.cos(azimuth), np.sin(azimuth)
    wire = dict(color='white', alpha=0.4, lw=1.2)

    # Latitude rings at the interior polar angles (both poles are skipped).
    for lat in np.linspace(0, np.pi, n_lines + 1)[1:-1]:
        ring_r = radius * np.sin(lat)
        ring_z = radius * np.cos(lat)
        ax.plot(ring_r * cos_az, ring_r * sin_az,
                ring_z * np.ones_like(azimuth), **wire)

    # Meridian arcs at evenly spaced azimuths.
    arc_rho = radius * np.sin(polar)
    arc_z = radius * np.cos(polar)
    for lon in np.linspace(0, 2 * np.pi, n_lines, endpoint=False):
        ax.plot(arc_rho * np.cos(lon), arc_rho * np.sin(lon), arc_z, **wire)

    # Equator, drawn thicker and in cyan.
    ax.plot(radius * cos_az, radius * sin_az,
            np.zeros_like(azimuth), color='#00e5ff', alpha=0.9, lw=3)
|
|
| |
if HAS_PLT:
    # Principal directions come from the centred first 5000 embeddings; the
    # projections below are applied to the uncentred normalized vectors.
    _pca_pool = embs_n[:5000]
    embs_c = _pca_pool - _pca_pool.mean(0, keepdim=True)
    Vt = torch.linalg.svd(embs_c, full_matrices=False).Vh
    _basis_2d = Vt[:2].T
    _basis_3d = Vt[:3].T
    proj_2d = (embs_n @ _basis_2d).numpy()
    proj_3d = (embs_n @ _basis_3d).numpy()
    anch_2d = (anchors_n @ _basis_2d).numpy()
    anch_3d = (anchors_n @ _basis_3d).numpy()
    proj_labels = labels.numpy()

    # Reference-sphere radius: 95th percentile of the projected point radii.
    emb_radii_2d = np.sqrt(proj_2d[:5000, 0] ** 2 + proj_2d[:5000, 1] ** 2)
    sphere_r_2d = np.percentile(emb_radii_2d, 95)

    emb_radii_3d = np.sqrt((proj_3d[:3000] ** 2).sum(axis=1))
    sphere_r_3d = np.percentile(emb_radii_3d, 95)

    # Largest absolute coordinate among plotted points and anchors; used as a
    # fallback radius when the percentile radius is degenerate (< 0.01).
    data_range_2d = max(np.abs(proj_2d[:5000]).max(), np.abs(anch_2d).max())
    data_range_3d = max(np.abs(proj_3d[:3000]).max(), np.abs(anch_3d).max())
    if sphere_r_2d < 0.01:
        sphere_r_2d = data_range_2d * 0.9
    if sphere_r_3d < 0.01:
        sphere_r_3d = data_range_3d * 0.9

    print(f" Sphere radius (2D): {sphere_r_2d:.4f} (3D): {sphere_r_3d:.4f}")
    print(f" Data range (2D): {data_range_2d:.4f} (3D): {data_range_3d:.4f}")
|
|
| |
if HAS_PLT:
    # Figure 1: 2D PCA scatter of the first 5000 embeddings, coloured by
    # class, with the anchors and the reference circle on top.
    print(" [1/8] PCA projection...")
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    pts_2d = proj_2d[:5000]
    lbl_2d = proj_labels[:5000]
    for c in range(N_CLASSES):
        mask = lbl_2d == c
        if mask.sum() == 0:
            continue
        # Per-class legend entries only when the legend stays readable.
        lbl = CLASS_NAMES[c] if N_CLASSES <= 20 else None
        ax.scatter(pts_2d[mask, 0], pts_2d[mask, 1],
                   c=CLASS_COLORS[c], s=4, alpha=0.5, label=lbl, zorder=2)
    ax.scatter(anch_2d[:, 0], anch_2d[:, 1],
               c='#FFD700', s=60, marker='*', edgecolors='white',
               linewidths=0.3, zorder=5, label='anchors')

    draw_sphere_grid_2d(ax, sphere_r_2d)
    if N_CLASSES <= 20:
        ax.legend(fontsize=7, markerscale=2, loc='upper right', ncol=2)
    ax.set_title(f'GeoLIP Core — PCA Embedding Space ({ds_name})\n'
                 f'val={val_acc:.1f}% | {total_params:,} params | '
                 f'CV={v_cv:.4f} | {n_active}/{n_anchors} anchors', fontsize=11)
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.15, color='#555577')
    plt.tight_layout()
    save_fig('01_pca_embedding_space.png')
    plt.close()
|
|
| |
if HAS_PLT:
    # Figure 2: lines from the first (up to) 500 samples to their nearest
    # anchor, over the class-coloured scatter; anchors are annotated with the
    # most frequent class among the samples assigned to them.
    print(" [2/8] Triangulation connections...")
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    subset = min(500, len(embs))
    # Convert once to Python ints instead of calling .item() per sample.
    near_sub = nearest[:subset].tolist()
    lbl_sub = labels[:subset].tolist()
    for i, (a_idx, cls_idx) in enumerate(zip(near_sub, lbl_sub)):
        ax.plot([proj_2d[i, 0], anch_2d[a_idx, 0]],
                [proj_2d[i, 1], anch_2d[a_idx, 1]],
                c=CLASS_COLORS[cls_idx], alpha=0.1, linewidth=0.5)

    pts_2d = proj_2d[:5000]
    lbl_2d = proj_labels[:5000]
    for c in range(N_CLASSES):
        mask = lbl_2d == c
        if mask.sum() == 0:
            continue
        ax.scatter(pts_2d[mask, 0], pts_2d[mask, 1],
                   c=CLASS_COLORS[c], s=5, alpha=0.4, zorder=2)
    ax.scatter(anch_2d[:, 0], anch_2d[:, 1],
               c='#FFD700', s=80, marker='*', edgecolors='white',
               linewidths=0.3, zorder=5)

    # Label anchors with their dominant class only when there are few enough.
    if n_anchors <= 128:
        for a in range(n_anchors):
            a_mask = nearest == a
            if a_mask.sum() > 0:
                dom_class = labels[a_mask].mode().values.item()
                ax.annotate(str(dom_class), (anch_2d[a, 0], anch_2d[a, 1]),
                            fontsize=4, ha='center', va='center',
                            color='white', fontweight='bold',
                            bbox=dict(boxstyle='round,pad=0.1',
                                      fc=CLASS_COLORS[dom_class],
                                      ec='#FFD700', linewidth=0.5,
                                      alpha=0.85))

    draw_sphere_grid_2d(ax, sphere_r_2d)
    ax.set_title(f'Triangulation: Image → Nearest Anchor ({ds_name})', fontsize=11)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.15, color='#555577')
    plt.tight_layout()
    save_fig('02_triangulation_connections.png')
    plt.close()
|
|
| |
if HAS_PLT:
    # Figure 3: 3D PCA scatter (first 3000 samples, at most the first 20
    # classes) with the anchors and a wireframe reference sphere.
    print(" [3/8] 3D sphere projection...")
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    n_3d = min(3000, len(embs))
    pts_3d = proj_3d[:n_3d]
    lbl_3d = proj_labels[:n_3d]
    for c in range(min(N_CLASSES, 20)):
        mask = lbl_3d == c
        if mask.sum() == 0:
            continue
        ax.scatter(pts_3d[mask, 0], pts_3d[mask, 1], pts_3d[mask, 2],
                   c=CLASS_COLORS[c], s=5, alpha=0.4,
                   label=CLASS_NAMES[c] if N_CLASSES <= 20 else None)
    ax.scatter(anch_3d[:, 0], anch_3d[:, 1], anch_3d[:, 2],
               c='#FFD700', s=40, marker='*', edgecolors='white',
               linewidths=0.3, zorder=5)

    draw_sphere_grid_3d(ax, sphere_r_3d)
    if N_CLASSES <= 20:
        ax.legend(fontsize=6, markerscale=2, loc='upper left', ncol=2)
    ax.set_title(f'3D PCA — Constellation on the Sphere\n'
                 f'{n_anchors} anchors, {N_CLASSES} classes', fontsize=11)
    # Deliberate best-effort: skip equal box aspect when the axes object
    # does not provide set_box_aspect.
    try:
        ax.set_box_aspect([1, 1, 1])
    except AttributeError:
        pass
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    plt.tight_layout()
    save_fig('03_3d_sphere.png')
    plt.close()
|
|
| |
if HAS_PLT:
    # Figure 4: heat map of how each class distributes over the anchors.
    print(" [4/8] Anchor-class assignment matrix...")
    # Joint (class, anchor) counts in one pass: encode each pair as a flat
    # index and histogram it. Labels outside [0, N_CLASSES) are ignored,
    # matching a per-class loop over range(N_CLASSES).
    in_range = labels < N_CLASSES
    pair_idx = labels[in_range] * n_anchors + nearest[in_range]
    assign_mat = torch.bincount(
        pair_idx, minlength=N_CLASSES * n_anchors
    ).view(N_CLASSES, n_anchors).float()
    # Row-normalize so every class row sums to ~1.
    assign_norm = assign_mat / (assign_mat.sum(dim=1, keepdim=True) + 1e-8)

    # Order the anchor columns by the class in which each anchor peaks.
    peak_class = assign_norm.argmax(dim=0)
    sort_order = peak_class.argsort()
    assign_sorted = assign_norm[:, sort_order]

    h = max(6, N_CLASSES * 0.12)
    fig, ax = plt.subplots(1, 1, figsize=(16, h))
    im = ax.imshow(assign_sorted.numpy(), aspect='auto', cmap='inferno')
    if N_CLASSES <= 30:
        ax.set_yticks(range(N_CLASSES))
        ax.set_yticklabels(CLASS_NAMES, fontsize=max(4, 9 - N_CLASSES // 15))
    ax.set_xlabel('Anchor index (sorted by peak class)')
    ax.set_title(f'Class → Anchor Assignment ({ds_name})', fontsize=11)
    plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout()
    save_fig('04_anchor_class_heatmap.png')
    plt.close()
|
|
| |
if HAS_PLT:
    # Figure 5: per-class mean +/- std of the triangulation distance to every
    # anchor, one panel per displayed class (2 x 5 grid).
    print(" [5/8] Class triangulation profiles...")
    if N_CLASSES > 10:
        # Too many classes for ten panels: show the 5 least and 5 most accurate.
        sorted_by_acc = sorted(range(N_CLASSES), key=lambda k: class_accs[k])
        show_classes = sorted_by_acc[:5] + sorted_by_acc[-5:]
    else:
        show_classes = list(range(N_CLASSES))

    nrows, ncols = 2, 5
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 8))
    anchor_axis = np.arange(n_anchors)
    for panel, cls_idx in zip(axes.ravel(), show_classes):
        cls_tris = tris[labels == cls_idx]
        if len(cls_tris) == 0:
            continue
        mean_tri = cls_tris.mean(0).numpy()
        std_tri = cls_tris.std(0).numpy()
        shade = CLASS_COLORS[cls_idx]
        panel.fill_between(anchor_axis, mean_tri - std_tri, mean_tri + std_tri,
                           alpha=0.3, color=shade)
        panel.plot(anchor_axis, mean_tri, color=shade, linewidth=1.5)
        panel.set_title(f'{CLASS_NAMES[cls_idx]} ({class_accs[cls_idx]:.0f}%)',
                        fontsize=9, fontweight='bold', color=shade)
        panel.set_ylim(0, max(1.6, mean_tri.max() * 1.2))
        panel.tick_params(labelsize=5)

    if N_CLASSES <= 10:
        tag = "all classes"
    else:
        tag = "5 worst + 5 best"
    plt.suptitle(f'Triangulation Fingerprints ({tag})', fontsize=12)
    plt.tight_layout()
    save_fig('05_triangulation_profiles.png')
    plt.close()
|
|
| |
if HAS_PLT:
    # Figure 6: samples per anchor (left) and the per-class entropy of the
    # anchor assignment (right), titled with the Gini coefficient of usage.
    print(" [6/8] Anchor utilization...")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    sorted_counts, _ = counts.sort(descending=True)
    sorted_list = sorted_counts.tolist()
    # Unused anchors (zero samples) are highlighted in red.
    ax1.bar(range(n_anchors), sorted_list,
            color=['#00BCD4' if n > 0 else '#FF5252' for n in sorted_list],
            width=1.0)
    ax1.set_xlabel('Anchor (sorted)')
    ax1.set_ylabel('Assigned samples')
    ax1.set_title(f'Anchor Utilization ({n_active}/{n_anchors} active)')
    # Dashed line: the count every anchor would have under uniform usage.
    ax1.axhline(y=len(labels) / n_anchors, color='#888899', linestyle='--', alpha=0.5)

    # Entropy of each class's distribution over its nearest anchors.
    entropies = []
    for c in range(N_CLASSES):
        dist = torch.bincount(nearest[labels == c], minlength=n_anchors).float()
        dist = dist / (dist.sum() + 1e-8)
        ent = -(dist * (dist + 1e-10).log()).sum().item()
        entropies.append(ent)

    if N_CLASSES <= 20:
        ax2.barh(range(N_CLASSES), entropies,
                 color=[CLASS_COLORS[c] for c in range(N_CLASSES)])
        ax2.set_yticks(range(N_CLASSES))
        ax2.set_yticklabels(CLASS_NAMES, fontsize=8)
        ax2.set_xlabel('Anchor assignment entropy')
    else:
        ax2.hist(entropies, bins=30, color='#00BCD4', edgecolor='#333355')
        ax2.set_xlabel('Anchor assignment entropy')
        ax2.set_ylabel('Number of classes')

    # Gini coefficient of anchor usage from the ascending-sorted counts:
    #   G = (n + 1 - 2 * sum(cumsum) / total) / n
    # The previous expression omitted the (n + 1) / n term, so perfectly
    # uniform usage produced -1/n instead of 0.
    c_sorted = counts.double().sort().values
    n_sorted = len(c_sorted)
    usage_total = c_sorted.sum().item()
    if usage_total > 0:
        gini = (n_sorted + 1 - 2 * c_sorted.cumsum(0).sum().item() / usage_total) / n_sorted
    else:
        gini = 0.0
    ax2.set_title(f'Anchor Spread (Gini={gini:.3f})')
    plt.tight_layout()
    save_fig('06_anchor_utilization.png')
    plt.close()
|
|
| |
if HAS_PLT:
    # Figure 7: for every patchwork compartment, a heat map of the class-mean
    # triangulation distances over the anchors assigned to that compartment.
    print(" [7/8] Patchwork compartment responses...")
    n_comp = cfg.get('n_comp', 8)
    # NOTE(review): asgn is used as a per-anchor compartment index (it is
    # compared with k to select triangulation columns) — confirm with the model.
    asgn = model.patchwork.asgn.cpu()

    show_c = show_classes if N_CLASSES > 10 else list(range(N_CLASSES))

    ncols_pw = min(4, n_comp)
    nrows_pw = math.ceil(n_comp / ncols_pw)
    # squeeze=False always yields a 2D array of axes, so no special-casing of
    # single-row / single-panel layouts is needed before flattening row-major.
    fig, axes = plt.subplots(nrows_pw, ncols_pw,
                             figsize=(4 * ncols_pw, 3 * nrows_pw), squeeze=False)
    axes_flat = list(axes.ravel())

    for k in range(min(n_comp, len(axes_flat))):
        panel = axes_flat[k]
        comp_tris = tris[:, asgn == k]
        class_means = []
        class_labels_show = []
        for cls_idx in show_c:
            rows = comp_tris[labels == cls_idx]
            if len(rows) > 0:
                class_means.append(rows.mean(0).numpy())
                class_labels_show.append(CLASS_NAMES[cls_idx])
        if not class_means:
            continue
        class_means = np.stack(class_means)
        panel.imshow(class_means, aspect='auto', cmap='plasma')
        panel.set_yticks(range(len(class_labels_show)))
        panel.set_yticklabels(class_labels_show, fontsize=6)
        panel.set_title(f'Comp {k}', fontsize=9)

    # Hide grid cells that have no compartment to show.
    for spare in axes_flat[n_comp:]:
        spare.set_visible(False)
    plt.suptitle('Patchwork Compartment Responses by Class', fontsize=12)
    plt.tight_layout()
    save_fig('07_patchwork_compartments.png')
    plt.close()
|
|
| |
if HAS_PLT:
    # Figure 8: row-normalized confusion matrix (true class on rows).
    print(" [8/8] Confusion matrix...")
    # Count (true, predicted) pairs in one pass instead of a per-sample loop;
    # pairs with an index outside [0, N_CLASSES) are skipped.
    in_range = (labels < N_CLASSES) & (preds < N_CLASSES)
    pair_idx = labels[in_range] * N_CLASSES + preds[in_range]
    conf_mat = torch.bincount(
        pair_idx, minlength=N_CLASSES * N_CLASSES
    ).view(N_CLASSES, N_CLASSES)
    conf_pct = conf_mat.float() / (conf_mat.sum(dim=1, keepdim=True) + 1e-8) * 100

    if N_CLASSES <= 20:
        # Small matrix: annotate every cell with its percentage.
        fig, ax = plt.subplots(1, 1, figsize=(8, 7))
        im = ax.imshow(conf_pct.numpy(), cmap='magma', vmin=0, vmax=100)
        for i in range(N_CLASSES):
            for j in range(N_CLASSES):
                v = conf_pct[i, j].item()
                # Dark text on bright cells, light text on dark cells.
                ax.text(j, i, f'{v:.0f}', ha='center', va='center',
                        fontsize=max(4, 8 - N_CLASSES // 5),
                        color='black' if v > 60 else '#e0e0e0')
        ax.set_xticks(range(N_CLASSES))
        ax.set_yticks(range(N_CLASSES))
        ax.set_xticklabels(CLASS_NAMES, rotation=45, ha='right', fontsize=7)
        ax.set_yticklabels(CLASS_NAMES, fontsize=7)
    else:
        fig, ax = plt.subplots(1, 1, figsize=(14, 12))
        im = ax.imshow(conf_pct.numpy(), cmap='magma', vmin=0, vmax=100)
    ax.set_xlabel('Predicted class')
    ax.set_ylabel('True class')
    ax.set_title(f'Confusion Matrix — {val_acc:.1f}% ({ds_name})', fontsize=11)
    plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout()
    save_fig('08_confusion_matrix.png')
    plt.close()

    print(f"\n ✓ All 8 visualizations saved to {OUT_DIR}/")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*70}")
print("SUMMARY")
print(f"{'='*70}")
print(f" Dataset: {ds_name} ({N_CLASSES} classes)")
print(f" Params: {total_params:,}")
print(f" Val accuracy: {val_acc:.1f}%")
print(f" Eff dim: {eff_dim:.1f}/{embs.shape[1]}")
print(f" Same-class cos: {same_cos:.4f}")
print(f" Diff-class cos: {diff_cos:.4f}")
print(f" Gap: {gap:.4f}")
print(f" CV: {v_cv:.4f}")
print(f" Anchors active: {n_active}/{n_anchors}")

# Classes with the lowest and highest per-class accuracy.
worst_i = min(range(N_CLASSES), key=lambda c: class_accs[c])
best_i = max(range(N_CLASSES), key=lambda c: class_accs[c])
print(f" Worst class: {CLASS_NAMES[worst_i]} ({class_accs[worst_i]:.1f}%)")
print(f" Best class: {CLASS_NAMES[best_i]} ({class_accs[best_i]:.1f}%)")

# Heuristic health checks on the metrics computed by the audits.
warnings = []
if n_active < n_anchors * 0.5:
    warnings.append(f"Anchor collapse: {n_active}/{n_anchors}")
if eff_dim < 5:
    warnings.append(f"Embedding collapse: eff_dim={eff_dim:.1f}")
if gap < 0.02:
    warnings.append(f"Low class separation: gap={gap:.4f}")

if warnings:
    print(f"\n ⚠ WARNINGS: {', '.join(warnings)}")
else:
    print(f"\n ✓ All diagnostics healthy")

print(f"\n{'='*70}")
print("ANALYSIS COMPLETE")
print(f"{'='*70}")
|
|
| |
| |
| |
|
|
# Optionally publish the output directory to the Hugging Face Hub under
# "analysis/" in HF_REPO_ID; the commit message carries the headline metrics.
if HF_PUSH:
    from huggingface_hub import upload_folder
    print(f"\n Uploading {OUT_DIR}/ → {HF_REPO_ID}/analysis/ ...")
    upload_folder(
        repo_id=HF_REPO_ID,
        folder_path=OUT_DIR,
        path_in_repo="analysis",
        commit_message=f"Analysis: val={val_acc:.1f}% CV={v_cv:.4f} {n_active}/{n_anchors} anchors",
    )
    print(f" ✓ Done: https://huggingface.co/{HF_REPO_ID}/tree/main/analysis")