geolip-constellation-core / trainer_model.py
AbstractPhil's picture
Update trainer_model.py
e255399 verified
#!/usr/bin/env python3
"""
GeoLIP Core — Back to Basics
==============================
Conv encoder → sphere → constellation → patchwork → classifier.
No streams. No GAL. No Procrustes. No mastery queue.
Just the geometric classification pipeline.
Two augmented views → CE + InfoNCE + nearest-anchor attraction + CV + anchor spread.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os, time
import numpy as np
from itertools import combinations
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Permit TF32 math in both cuBLAS matmuls and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
# ══════════════════════════════════════════════════════════════════
# UNIFORM HYPERSPHERE INIT
# ══════════════════════════════════════════════════════════════════
def uniform_hypersphere_init(n, d):
    """Return ``n`` unit vectors in ``R^d`` spread over the hypersphere.

    When ``n <= d`` the rows are exactly orthonormal (QR of a Gaussian
    matrix). Otherwise a full orthonormal basis is extended with ``n - d``
    random unit vectors, and the whole set is relaxed for 200 steps by
    pushing every vector a little away from its current nearest neighbour.

    Returns:
        Tensor of shape ``(n, d)`` whose rows have unit norm.
    """
    if n > d:
        q, _ = torch.linalg.qr(torch.randn(d, d))
        extra = F.normalize(torch.randn(n - d, d), dim=-1)
        points = torch.cat([q.T, extra], dim=0)
        for _ in range(200):
            gram = points @ points.T
            gram.fill_diagonal_(-2.0)  # never pick a vector as its own neighbour
            neighbour = points[gram.argmax(dim=1)]
            points = F.normalize(points - 0.05 * neighbour, dim=-1)
        return points
    q, _ = torch.linalg.qr(torch.randn(d, n))
    return q.T.contiguous()
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION + PATCHWORK
# ══════════════════════════════════════════════════════════════════
class Constellation(nn.Module):
    """Learnable set of anchor directions used to triangulate embeddings.

    Args:
        n_anchors: number of anchors.
        dim: anchor / embedding dimensionality.
        anchor_drop: per-anchor drop probability applied by
            ``triangulate(..., training=True)``.
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        self.anchors = nn.Parameter(uniform_hypersphere_init(n_anchors, dim))
        self.anchor_drop = anchor_drop

    def triangulate(self, emb, training=False):
        """Return ``(1 - cosine)`` distances to the anchors and the nearest anchor.

        Anchors are L2-normalised first; ``emb`` is used as given (callers in
        this file pass unit-norm rows). With ``training=True`` and
        ``anchor_drop > 0`` a random subset of anchors is kept (the first two
        are forced on if fewer than two survive): the distance matrix then only
        has columns for the kept anchors, while ``nearest`` is mapped back to
        indices into the full anchor set.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        use_drop = training and self.anchor_drop > 0
        if use_drop:
            keep = torch.rand(unit_anchors.shape[0], device=unit_anchors.device) > self.anchor_drop
            if keep.sum() < 2:
                keep[:2] = True
            unit_anchors = unit_anchors[keep]
        cos = emb @ unit_anchors.T
        nearest = cos.max(dim=-1).indices
        if use_drop:
            # Translate positions among kept anchors into global anchor ids.
            nearest = keep.nonzero(as_tuple=True)[0][nearest]
        return 1.0 - cos, nearest
class Patchwork(nn.Module):
    """Embed interleaved groups of triangulation columns with per-group MLPs.

    Anchor ``i`` is routed to component ``i % n_comp``. Each component is a
    Linear → GELU → Linear → LayerNorm stack producing ``d_comp`` features,
    and the component outputs are concatenated (``n_comp * d_comp`` total).

    Args:
        n_anchors: number of triangulation columns (anchors).
        n_comp: number of components. ``n_anchors`` need not be divisible by
            it; the first ``n_anchors % n_comp`` components get one extra column.
        d_comp: output width of each component.
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        # Size each MLP from the number of columns actually routed to it.
        # (Using n_anchors // n_comp for every component breaks forward()
        # whenever n_anchors is not a multiple of n_comp.)
        counts = torch.bincount(self.asgn, minlength=n_comp).tolist()
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear(count, d_comp * 2), nn.GELU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for count in counts])

    def forward(self, tri):
        """Apply each component to its anchor columns and concatenate.

        Assumes ``tri`` is ``(batch, n_anchors)`` with columns in anchor order
        (i.e. the full, non-dropped triangulation) — TODO confirm at call sites.
        """
        return torch.cat([comp(tri[:, self.asgn == k])
                          for k, comp in enumerate(self.comps)], dim=-1)
# ══════════════════════════════════════════════════════════════════
# CONV ENCODER
# ══════════════════════════════════════════════════════════════════
class ConvEncoder(nn.Module):
    """
    Plain convolutional backbone mapping images to a flat feature vector.

    Three stages of (Conv3x3 → BatchNorm → GELU) × 2 with 64, 128 and 256
    channels, each stage closed by a 2×2 max-pool (32×32 input → 4×4), then
    global average pooling and a Linear + LayerNorm projection to
    ``output_dim``. No attention, no geometric layers.
    """

    def __init__(self, output_dim=128):
        super().__init__()
        layers = []
        in_ch = 3
        for out_ch in (64, 128, 256):
            # Two conv units per stage, then halve the spatial resolution.
            for _ in range(2):
                layers.extend([
                    nn.Conv2d(in_ch, out_ch, 3, padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.GELU(),
                ])
                in_ch = out_ch
            layers.append(nn.MaxPool2d(2))
        layers.extend([nn.AdaptiveAvgPool2d(1), nn.Flatten()])
        self.features = nn.Sequential(*layers)
        self.proj = nn.Sequential(nn.Linear(in_ch, output_dim), nn.LayerNorm(output_dim))

    def forward(self, x):
        pooled = self.features(x)
        return self.proj(pooled)
# ══════════════════════════════════════════════════════════════════
# GEOLIP CORE
# ══════════════════════════════════════════════════════════════════
class GeoLIPCore(nn.Module):
    """Conv encoder → unit sphere → constellation → patchwork → classifier.

    The encoder output is L2-normalised, triangulated against the learnable
    anchors of a ``Constellation``, embedded per anchor group by a
    ``Patchwork``, and classified from the concatenation of the patchwork
    features and the embedding itself.

    Args:
        num_classes: number of output logits.
        output_dim: embedding (sphere) dimensionality.
        n_anchors: number of constellation anchors.
        n_comp: number of patchwork components.
        d_comp: output width of each patchwork component.
        anchor_drop: anchor drop probability used for training-time
            nearest-anchor tracking.
        cv_target: target coefficient of variation of simplex volumes.
        infonce_temp: temperature dividing the InfoNCE similarity logits.
    """

    def __init__(
        self,
        num_classes=10,
        output_dim=128,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        cv_target=0.22,
        infonce_temp=0.07,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp
        # Snapshot the constructor arguments (saved with checkpoints). This must
        # run before any other local is bound; '_'-prefixed names such as the
        # implicit __class__ cell are skipped.
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}
        self.encoder = ConvEncoder(output_dim)
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes))
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal Linear weights, Kaiming (fan_out) conv weights,
        zero biases, and unit/zero affine params for BatchNorm2d/LayerNorm.
        The constellation anchors are not touched here."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the full pipeline on an image batch.

        ``x`` is expected to be ``(B, 3, H, W)`` (the encoder's first conv
        takes 3 channels).

        Returns:
            dict with ``logits`` (B, num_classes), ``embedding`` (B, output_dim,
            unit norm), ``triangulation`` (B, n_anchors; 1 - cos to every
            anchor) and ``nearest`` (B,) nearest-anchor index. In training mode
            ``nearest`` is computed with anchor dropout.
        """
        feat = self.encoder(x)
        emb = F.normalize(feat, dim=-1)
        # Full tri for patchwork (needs all anchor columns)
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)
        # Dropout version for nearest tracking only
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)
        logits = self.classifier(torch.cat([pw, emb], dim=-1))
        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Combine the training objectives.

        total = CE + 1.0·InfoNCE (only if ``output_aug`` is given)
                + 0.5·attract + 0.01·CV + 0.001·spread

        Returns:
            ``(loss, ld)`` where ``ld`` holds the loss terms as tensors
            (``ce``, ``nce``, ``attract``, ``cv``, ``spread``, ``total``) and
            Python-float diagnostics (``acc``, ``nce_acc``, ``nearest_cos``).
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]
        # CE
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()
        # InfoNCE: row i of view 1 should match row i of view 2.
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            ld['nce'] = l_nce
            ld['nce_acc'] = nce_acc
        # ── Anchor attraction: pull each embedding toward its nearest anchor ──
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T  # (B, n_anchors)
        nearest_cos = cos_to_anchors.max(dim=1).values  # (B,)
        l_attract = (1.0 - nearest_cos).mean()  # 0 when on top of anchor
        ld['attract'] = l_attract
        ld['nearest_cos'] = nearest_cos.mean().item()
        # CV
        l_cv = self._cv_loss(emb)
        ld['cv'] = l_cv
        # Anchor spread: penalise positive cosine between distinct anchors.
        sim_a = anchors_n @ anchors_n.T
        mask = ~torch.eye(anchors_n.shape[0], dtype=torch.bool, device=anchors_n.device)
        l_spread = F.relu(sim_a[mask]).mean()
        ld['spread'] = l_spread
        # Total
        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_attract * 0.5
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    @torch.no_grad()
    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """
        Push anchors toward CLASS centroids, not nearest-anchor centroids.

        Phase 1: Compute normalised class centroids for the labels present.
        Phase 2: Greedily assign anchors to classes by descending cosine, at
                 most ``n_anchors // n_classes + 1`` per class; any leftover
                 anchor goes to its nearest centroid.
        Phase 3: Move each anchor a fraction ``lr`` of the way toward its
                 class centroid and re-normalise. Every anchor after the first
                 in a class aims at a slightly perturbed centroid so co-class
                 anchors don't collapse onto one point.

        Args:
            emb_buffer: (N, D) embeddings (normalised here).
            label_buffer: (N,) labels aligned with ``emb_buffer``.
            lr: blend rate toward the target (1.0 jumps onto it).

        Returns:
            Number of anchors moved, or 0 when the buffer holds no labels.
        """
        anchors = self.constellation.anchors.data  # (A, D), edited in place
        n_a = anchors.shape[0]
        emb_n = F.normalize(emb_buffer, dim=-1)

        # Phase 1: one normalised centroid per label present in the buffer.
        classes = label_buffer.unique()
        if classes.numel() == 0:
            return 0
        centroids = torch.cat(
            [F.normalize(emb_n[label_buffer == c].mean(0, keepdim=True), dim=-1)
             for c in classes], dim=0)  # (C, D)
        n_cls = centroids.shape[0]

        # Phase 2: capacity-limited greedy assignment. Done with Python ints
        # on the host: iterating device tensors would force a sync per pair.
        anchors_n = F.normalize(anchors, dim=-1)
        cos = anchors_n @ centroids.T  # (A, C)
        anchors_per_class = n_a // n_cls
        capacity = anchors_per_class + 1  # +1 absorbs the remainder
        order = cos.flatten().sort(descending=True).indices.tolist()
        assigned = [-1] * n_a
        class_count = [0] * n_cls
        n_assigned = 0
        for flat in order:
            a, c = divmod(flat, n_cls)
            if assigned[a] >= 0 or class_count[c] >= capacity:
                continue
            assigned[a] = c
            class_count[c] += 1
            n_assigned += 1
            if n_assigned == n_a:
                break
        # Unassigned leftovers → nearest centroid
        leftovers = [a for a, c in enumerate(assigned) if c < 0]
        if leftovers:
            best = (anchors_n[leftovers] @ centroids.T).argmax(dim=1).tolist()
            for a, c in zip(leftovers, best):
                assigned[a] = c

        # Phase 3: blend each anchor toward its (possibly perturbed) target.
        moved = 0
        rank = [0] * n_cls  # anchors of each class already processed
        for a, c in enumerate(assigned):
            target = centroids[c]
            # Offset every co-class anchor after the first. This must not
            # depend on anchors_per_class: the +1 capacity lets a class hold
            # two anchors even when anchors_per_class <= 1.
            if rank[c] > 0:
                noise = torch.randn_like(target) * 0.05
                noise = noise - (noise * target).sum() * target  # project out radial
                target = F.normalize((target + noise).unsqueeze(0), dim=-1).squeeze(0)
            rank[c] += 1
            anchors[a] = F.normalize(
                (anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
                dim=-1).squeeze(0)
            moved += 1
        return moved

    def _cv_loss(self, emb, n_samples=64, n_points=5):
        """Squared deviation of the simplex-volume CV from ``self.cv_target``.

        Draws ``n_samples`` random ``n_points``-subsets from the first
        ``min(B, 512)`` rows of ``emb``, computes each simplex volume via the
        Cayley–Menger determinant, and returns ``(std/mean - cv_target)**2``
        over volumes with V² > 1e-20. Returns a zero tensor when
        ``B < n_points`` or fewer than 5 volumes survive.
        """
        B = emb.shape[0]
        if B < n_points or n_samples < 5:
            return torch.tensor(0.0, device=emb.device)
        pool = min(B, 512)
        # Same randperm draws as a per-sample loop, but the volume maths is
        # batched so there is a single host sync (the survivor count).
        idx = torch.stack([torch.randperm(pool, device=emb.device)[:n_points]
                           for _ in range(n_samples)])  # (S, N)
        pts = emb[idx]  # (S, N, D)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
        N = n_points
        cm = torch.zeros(n_samples, N + 1, N + 1, device=emb.device, dtype=emb.dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        k = N - 1
        # Cayley–Menger: V² = (-1)^(k+1) / (2^k · (k!)²) · det(CM) for a k-simplex.
        pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
        v2 = pf * torch.linalg.det(cm.float())  # (S,)
        valid = v2 > 1e-20
        if int(valid.sum()) < 5:
            return torch.tensor(0.0, device=emb.device)
        vt = v2[valid].to(emb.dtype).sqrt()
        cv = vt.std() / (vt.mean() + 1e-8)
        return (cv - self.cv_target).pow(2)
# ══════════════════════════════════════════════════════════════════
# DATA
# ══════════════════════════════════════════════════════════════════
# Per-channel normalisation statistics handed to transforms.Normalize for CIFAR images.
CIFAR_MEAN, CIFAR_STD = (
    (0.4914, 0.4822, 0.4465),  # mean
    (0.2470, 0.2435, 0.2616),  # std
)
class TwoViewDataset(torch.utils.data.Dataset):
    """Wrap a ``(image, label)`` dataset so each item yields two views.

    ``transform`` is applied twice, independently, to the same underlying
    image; with a random transform the two views generally differ.
    Items are returned as ``(view_1, view_2, label)``.
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        sample, target = self.base[i]
        first = self.transform(sample)
        second = self.transform(sample)
        return first, second, target
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
# Config
NUM_CLASSES = 10
OUTPUT_DIM = 128
N_ANCHORS = 64
N_COMP = 8
D_COMP = 64
BATCH = 256
EPOCHS = 100
LR = 3e-4
print("=" * 60)
print("GeoLIP Core — Conv + Constellation + Patchwork")
print(f" Encoder: 6-layer conv → {OUTPUT_DIM}-d sphere")
print(f" Constellation: {N_ANCHORS} anchors, {N_COMP}×{D_COMP} patchwork")
print(" Loss: CE + InfoNCE + CV(0.22)")
print(f" Batch: {BATCH}, LR: {LR}, Epochs: {EPOCHS}")
print(f" Device: {DEVICE}")
print("=" * 60)
aug_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = TwoViewDataset(raw_train, aug_transform)
val_ds = datasets.CIFAR10(root='./data', train=False,
                          download=True, transform=val_transform)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)
print(f" Train: {len(train_ds):,} Val: {len(val_ds):,}")
# Build
model = GeoLIPCore(
    num_classes=NUM_CLASSES, output_dim=OUTPUT_DIM,
    n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(f" Parameters: {n_params:,}")
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
total_steps = len(train_loader) * EPOCHS
warmup_steps = len(train_loader) * 3
# 3 epochs of linear warm-up, then cosine decay to 1e-6 (stepped per batch).
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_steps),
     torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
    milestones=[warmup_steps])
scaler = torch.amp.GradScaler("cuda")
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/geolip_core")
best_acc = 0.0
gs = 0  # global optimizer step
# Anchor push config
PUSH_INTERVAL = 50  # batches between centroid pushes
PUSH_LR = 0.1  # blend rate toward centroid
PUSH_BUFFER_SIZE = 5000
emb_buffer = None  # (N, D) rolling window of recent embeddings
lbl_buffer = None  # (N,) matching labels
push_count = 0
print(f"\n{'='*60}")
print(f"TRAINING — {EPOCHS} epochs")
print(f" Anchor push: every {PUSH_INTERVAL} batches, lr={PUSH_LR}")
print(f"{'='*60}")
for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    tot_loss, tot_nce_acc, tot_nearest_cos, n = 0, 0, 0, 0
    correct, total = 0, 0
    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for v1, v2, targets in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1)
            out2 = model(v2)
            loss, ld = model.compute_loss(out1, targets, output_aug=out2)
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        gs += 1
        # ── Accumulate embeddings for anchor push (keep the newest PUSH_BUFFER_SIZE) ──
        with torch.no_grad():
            batch_emb = out1['embedding'].detach().float()
            batch_lbl = targets.detach()
            if emb_buffer is None:
                emb_buffer, lbl_buffer = batch_emb, batch_lbl
            else:
                emb_buffer = torch.cat([emb_buffer, batch_emb])[-PUSH_BUFFER_SIZE:]
                lbl_buffer = torch.cat([lbl_buffer, batch_lbl])[-PUSH_BUFFER_SIZE:]
        # ── Periodic anchor push toward class centroids ──
        if gs % PUSH_INTERVAL == 0 and emb_buffer is not None and emb_buffer.shape[0] > 500:
            moved = model.push_anchors_to_centroids(
                emb_buffer, lbl_buffer, lr=PUSH_LR)
            push_count += 1
            writer.add_scalar("step/anchors_moved", moved, gs)
        preds = out1['logits'].argmax(-1)
        correct += (preds == targets).sum().item()
        total += targets.shape[0]
        tot_loss += loss.item()
        tot_nce_acc += ld.get('nce_acc', 0)
        tot_nearest_cos += ld.get('nearest_cos', 0)
        n += 1
        if n % 10 == 0:
            # Note: tqdm's set_postfix treats every kwarg as a display field,
            # so no extra option kwargs may be passed here.
            pbar.set_postfix(
                loss=f"{tot_loss/n:.4f}",
                acc=f"{100*correct/total:.0f}%",
                nce=f"{tot_nce_acc/n:.2f}",
                cos=f"{ld.get('nearest_cos', 0):.3f}",
                push=push_count)
    elapsed = time.time() - t0
    train_acc = 100 * correct / total
    # Val
    model.eval()
    vc, vt_n, vl, n_val_batches = 0, 0, 0.0, 0
    all_embs = []
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for imgs, lbls in val_loader:
            imgs = imgs.to(DEVICE)
            lbls = lbls.to(DEVICE)
            out = model(imgs)
            vc += (out['logits'].argmax(-1) == lbls).sum().item()
            vt_n += lbls.shape[0]
            vl += F.cross_entropy(out['logits'], lbls).item()
            n_val_batches += 1
            all_embs.append(out['embedding'].float().cpu())
    val_acc = 100 * vc / vt_n
    val_loss = vl / max(n_val_batches, 1)
    # CV of 4-simplex volumes over (up to) the first 2000 validation embeddings
    embs = torch.cat(all_embs)[:2000].to(DEVICE)
    n_cv = embs.shape[0]
    with torch.no_grad():
        vols = []
        for _ in range(200):
            idx = torch.randperm(n_cv)[:5]
            pts = embs[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
            cm[:, 0, 1:] = 1
            cm[:, 1:, 0] = 1
            cm[:, 1:, 1:] = d2
            # Cayley–Menger for a 4-simplex: V² = (-1)^5 / (2^4 · (4!)²) · det = -det / 9216
            vol2 = -torch.linalg.det(cm) / 9216
            if vol2[0].item() > 1e-20:
                vols.append(vol2[0].sqrt())
        if len(vols) > 10:
            vol_t = torch.stack(vols)
            v_cv = (vol_t.std() / (vol_t.mean() + 1e-8)).item()
        else:
            v_cv = 0
    # Anchors actually used as nearest neighbour by validation embeddings
    with torch.no_grad():
        _, vnp = model.constellation.triangulate(embs, training=False)
    n_active = vnp.cpu().unique().numel()
    writer.add_scalar("epoch/train_acc", train_acc, epoch+1)
    writer.add_scalar("epoch/val_acc", val_acc, epoch+1)
    writer.add_scalar("epoch/val_loss", val_loss, epoch+1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch+1)
    writer.add_scalar("epoch/anchors", n_active, epoch+1)
    writer.add_scalar("epoch/nearest_cos", tot_nearest_cos / n, epoch+1)
    writer.add_scalar("epoch/push_count", push_count, epoch+1)
    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
        }, "checkpoints/geolip_core_best.pt")
        mk = " ★"
    nce_m = tot_nce_acc / n
    cos_m = tot_nearest_cos / n
    cv_band = "✓" if 0.18 <= v_cv <= 0.25 else "✗"
    print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
          f"loss={tot_loss/n:.4f} nce={nce_m:.2f} cos={cos_m:.3f} "
          f"cv={v_cv:.4f}({cv_band}) anch={n_active}/{N_ANCHORS} "
          f"push={push_count} ({elapsed:.0f}s){mk}")
writer.close()
print(f"\n Best val accuracy: {best_acc:.1f}%")
print(f" Parameters: {n_params:,}")
print(f"\n{'='*60}")
print("DONE")
print(f"{'='*60}")