# tests.py — Flow Ensemble test suite (commit 9c2a7e1, verified)
# (web-page header from the file listing removed; original text:
#  "AbstractPhil's picture / Update tests.py / 9c2a7e1 verified")
"""
Flow Ensemble β€” Expanded Test Suite.
Assumes geolip-core is installed (Colab with repo loaded).
Tests: smoke, linalg integration, multi-scale, ensemble fusion,
gradient health, ablation, compile compatibility, memory.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys, time, gc
# ── Verify geolip_core.linalg is available ──
# Prefer the project's linalg dispatcher; fall back to torch.linalg so the
# suite still runs outside the geolip environment.  HAS_GEOLIP_LINALG gates
# the backend-specific checks in section 2.
try:
    import geolip_core.linalg as LA
    HAS_GEOLIP_LINALG = True
    print("geolip_core.linalg: available")  # fix: was an f-string with no placeholders (F541)
    # NOTE(review): only ImportError is caught below — if backend.status()
    # itself raises, the suite aborts here; confirm that is intended.
    LA.backend.status()
except ImportError:
    import torch.linalg as LA
    HAS_GEOLIP_LINALG = False
    print("geolip_core.linalg: NOT available, using torch.linalg fallback")
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def sync():
    """Block until queued CUDA work finishes; no-op on CPU."""
    if dev.type == 'cuda':
        torch.cuda.synchronize()

def time_fn(fn, warmup=5, runs=50):
    """Average wall-clock time of one `fn()` call, in milliseconds.

    Runs `warmup` untimed calls first, synchronizes, then times `runs`
    calls back-to-back and returns the per-call mean.
    """
    for _ in range(warmup):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    sync()
    elapsed = time.perf_counter() - start
    return elapsed / runs * 1000
def fmt(ms):
    """Render a duration in ms as 'NNNus' below 1 ms, else 'N.NNms'."""
    return f"{ms*1000:.0f}us" if ms < 1 else f"{ms:.2f}ms"
def make_data(B, n, k, d, device=None):
    """Random unit-norm test tensors.

    Returns (anchors, queries) with shapes (B, k, d) and (B, n, d),
    each row L2-normalized along the last dim.

    `device` defaults to the module-global `dev` (backward compatible);
    passing it explicitly lets callers place data elsewhere.
    """
    device = dev if device is None else device
    anchors = F.normalize(torch.randn(B, k, d, device=device), dim=-1)
    queries = F.normalize(torch.randn(B, n, d, device=device), dim=-1)
    return anchors, queries
# ═══════════════════════════════════════════════════════════════════
# Suite banner: device / backend summary before the numbered sections.
bar = "=" * 72
print(bar)
print(" Flow Ensemble β€” Expanded Test Suite")
print(bar)
print(f" device={dev} geolip_core.linalg={HAS_GEOLIP_LINALG}")
if dev.type == 'cuda':
    print(f" GPU: {torch.cuda.get_device_name()}")
print()
# ═══════════════════════════════════════════════════════════════════
# 1. SMOKE TEST β€” all flows, all shapes
# ═══════════════════════════════════════════════════════════════════
# Instantiate each flow once, run one forward pass, and print a one-line
# summary (params, output shape, latency, mean confidence, mean residual).
# Flows that construct and run successfully are kept in live_flows /
# flow_ctors for the later sections.
# NOTE(review): the flow classes (QuaternionFlow ... AlignmentFlow) are not
# imported in this file -- presumably defined in an earlier Colab cell; confirm.
print(f"{'='*72}\n 1. SMOKE TEST\n{'='*72}")
B, n, k, d = 16, 64, 32, 128
anchors, queries = make_data(B, n, k, d)
# (name, constructor) pairs; each ctor(d, k) is expected to return a module
# whose forward(anchors, queries) yields a (pred, conf) tuple.
flows_cfg = [
    ('QuaternionFlow', lambda d,k: QuaternionFlow(d, k, n_heads=4)),
    ('QuaternionLiteFlow', lambda d,k: QuaternionLiteFlow(d, k)),
    ('VelocityFlow', lambda d,k: VelocityFlow(d, k)),
    ('MagnitudeFlow', lambda d,k: MagnitudeFlow(d, k)),
    ('OrbitalFlow', lambda d,k: OrbitalFlow(d, k)),
    ('AlignmentFlow', lambda d,k: AlignmentFlow(d, k)),
]
print(f"\n {'Flow':<22} {'Params':>8} {'Shape':>14} {'Time':>10} {'Conf':>8} {'Res norm':>10}")
print(f" {'─'*22} {'─'*8} {'─'*14} {'─'*10} {'─'*8} {'─'*10}")
live_flows = []   # constructed instances that passed (reused in section 4)
flow_ctors = []   # (name, ctor) for passing flows (reused in sections 5-8)
for name, ctor in flows_cfg:
    try:
        flow = ctor(d, k).to(dev)
        params = sum(p.numel() for p in flow.parameters())
        pred, conf = flow(anchors, queries)
        ms = time_fn(lambda: flow(anchors, queries))
        # residual: how far the prediction moved from the input queries
        res = (pred - queries).norm(dim=-1).mean().item()
        shape_str = str(tuple(pred.shape))
        print(f" {name:<22} {params:>8,} {shape_str:>14} {fmt(ms):>10} {conf.mean().item():>8.3f} {res:>10.3f}")
        live_flows.append(flow)
        flow_ctors.append((name, ctor))
    except Exception as e:
        # keep the suite running; later sections simply skip failed flows
        print(f" {name:<22} FAILED: {str(e)[:50]}")
# ═══════════════════════════════════════════════════════════════════
# 2. LINALG INTEGRATION
# ═══════════════════════════════════════════════════════════════════
print(f"\n{'='*72}\n 2. LINALG INTEGRATION\n{'='*72}")
if HAS_GEOLIP_LINALG:
print(f"\n Testing eigh dispatch in MagnitudeFlow and OrbitalFlow...")
for FlowCls in [MagnitudeFlow, OrbitalFlow]:
flow = FlowCls(d, k).to(dev)
pred, conf = flow(anchors, queries)
ok = torch.isfinite(pred).all().item() and torch.isfinite(conf).all().item()
print(f" {flow.name:<18} finite={ok} conf={conf.mean():.3f}")
oflow = OrbitalFlow(d, k).to(dev)
a_geom = oflow.anchor_proj(anchors)
G = torch.bmm(a_geom.transpose(-2, -1), a_geom)
vals, vecs = LA.eigh(G)
print(f"\n Gram eigenspectrum: shape={tuple(vals.shape)} "
f"range=[{vals.min().item():.4f}, {vals.max().item():.4f}]")
print(f" Eigenvector orth err: {(torch.bmm(vecs.mT, vecs) - torch.eye(oflow.geom_dim, device=dev)).abs().max().item():.2e}")
else:
print(" Skipped β€” geolip_core.linalg not available")
# ═══════════════════════════════════════════════════════════════════
# 3. MULTI-SCALE
# ═══════════════════════════════════════════════════════════════════
# Run OrbitalFlow across a range of batch/sequence/anchor/dim sizes to
# catch shape bugs and non-finite outputs at scale.
print(f"\n{'='*72}\n 3. MULTI-SCALE\n{'='*72}")
# (B, n, k, d, label)
configs = [
    (4, 16, 8, 64, 'tiny'),
    (16, 64, 32, 128, 'small'),
    (32, 128, 64, 256, 'medium'),
    (64, 256, 128, 256, 'large'),
    (8, 512, 256, 512, 'wide'),
]
print(f"\n OrbitalFlow across scales:")
print(f" {'Config':<10} {'B':>4} {'n':>5} {'k':>5} {'d':>5} {'Time':>10} {'OK':>4}")
print(f" {'─'*10} {'─'*4} {'─'*5} {'─'*5} {'─'*5} {'─'*10} {'─'*4}")
for B_, n_, k_, d_, label in configs:
    try:
        of = OrbitalFlow(d_, k_).to(dev)
        a, q = make_data(B_, n_, k_, d_)
        pred, conf = of(a, q)
        ms = time_fn(lambda: of(a, q), warmup=3, runs=20)
        ok = torch.isfinite(pred).all().item()
        print(f" {label:<10} {B_:>4} {n_:>5} {k_:>5} {d_:>5} {fmt(ms):>10} {'OK' if ok else 'NO':>4}")
        del of, a, q  # free before the next (possibly larger) config
    except Exception as e:
        print(f" {label:<10} {B_:>4} {n_:>5} {k_:>5} {d_:>5} FAILED: {str(e)[:30]}")
# ═══════════════════════════════════════════════════════════════════
# 4. ENSEMBLE FUSION MODES
# ═══════════════════════════════════════════════════════════════════
# Build a FlowEnsemble from the smoke-test survivors and compare the three
# fusion modes.  'diversity' = 1 - mean pairwise cosine similarity between
# the member flows' individual predictions (higher = less redundant flows).
print(f"\n{'='*72}\n 4. ENSEMBLE FUSION\n{'='*72}")
B, n, k, d = 16, 64, 32, 128
anchors, queries = make_data(B, n, k, d)
for fusion in ['weighted', 'gated', 'residual']:
    ens = FlowEnsemble(live_flows, d, fusion=fusion).to(dev)
    out = ens(anchors, queries)
    ms = time_fn(lambda: ens(anchors, queries), warmup=3, runs=20)
    preds = [flow(anchors, queries)[0] for flow in ens.flows]
    cos_sims = []
    # all unordered pairs of member predictions
    for i in range(len(preds)):
        for j in range(i+1, len(preds)):
            cs = F.cosine_similarity(preds[i].flatten(1), preds[j].flatten(1), dim=-1).mean().item()
            cos_sims.append(cs)
    avg_sim = sum(cos_sims) / max(len(cos_sims), 1)  # guard: fewer than 2 flows has no pairs
    print(f"\n {fusion}: time={fmt(ms)} norm={out.norm(dim=-1).mean():.3f} diversity={1-avg_sim:.3f}")
    # NOTE(review): assumes flow_diagnostics returns
    # {name: {'confidence_mean', 'confidence_std', 'residual_norm'}} -- confirm.
    diag = ens.flow_diagnostics(anchors, queries)
    for fname, stats in diag.items():
        print(f" {fname:<18} conf={stats['confidence_mean']:.3f}Β±{stats['confidence_std']:.3f} "
              f"res={stats['residual_norm']:.3f}")
    del ens
# ═══════════════════════════════════════════════════════════════════
# 5. GRADIENT HEALTH
# ═══════════════════════════════════════════════════════════════════
# Backprop three loss shapes through a residual-fusion ensemble and report
# the per-flow gradient norm.  A known failure mode (autograd's in-place
# modification check tripping inside eigh-based flows) is detected and
# reported as a NOTE instead of crashing the suite.
print(f"\n{'='*72}\n 5. GRADIENT HEALTH\n{'='*72}")
B, n, k, d = 16, 64, 32, 128
anchors, queries = make_data(B, n, k, d)
losses = {
    'mse': (lambda o,q: (o - q).pow(2).mean()),
    'cosine': (lambda o,q: (1 - F.cosine_similarity(o, q, dim=-1)).mean()),
    'norm': (lambda o,q: o.norm(dim=-1).mean()),
}
print(f"\n {'Flow':<18} {'Loss':<10} {'Grad norm':>12} {'Status':>8}")
print(f" {'─'*18} {'─'*10} {'─'*12} {'─'*8}")
for loss_name, loss_fn in losses.items():
    # Fresh flows for each loss -- avoids grad accumulation across losses.
    try:
        test_flows_grad = [ctor(d, k).to(dev) for _, ctor in flow_ctors]
        ens_g = FlowEnsemble(test_flows_grad, d, fusion='residual').to(dev)
        ens_g.zero_grad()
        anchors_g = anchors.detach().clone().requires_grad_(True)
        queries_g = queries.detach().clone().requires_grad_(True)
        out = ens_g(anchors_g, queries_g)
        loss = loss_fn(out, queries_g.detach())
        loss.backward()
        for flow in ens_g.flows:
            grads = [p.grad for p in flow.parameters() if p.grad is not None]
            if grads:
                gn = torch.cat([g.flatten() for g in grads]).norm().item()
                # healthy range: neither vanished nor exploded
                status = "OK" if 1e-8 < gn < 1e4 else "WARN"
                print(f" {flow.name:<18} {loss_name:<10} {gn:>12.2e} {status:>8}")
            else:
                print(f" {flow.name:<18} {loss_name:<10} {'no grads':>12} {'WARN':>8}")
        del ens_g, test_flows_grad
    except RuntimeError as e:
        # Fix: lowercase the message ONCE and test all three substrings against
        # it -- the original checked 'modified by' against the raw message,
        # inconsistently with the two .lower() checks beside it.
        msg = str(e).lower()
        if 'inplace' in msg or 'in-place' in msg or 'modified by' in msg:
            print(f" {'*':>18} {loss_name:<10} {'IN-PLACE ERR':>12} {'NOTE':>8}")
            print(f" FL eigh deflation uses indexed assignment β€” needs .clone() fix")
        else:
            print(f" {'*':>18} {loss_name:<10} {'ERROR':>12}")
            print(f" {str(e)[:60]}")
# ═══════════════════════════════════════════════════════════════════
# 6. ABLATION β€” solo vs pairs vs full ensemble
# ═══════════════════════════════════════════════════════════════════
# Learning target: the queries mapped through one fixed random rotation.
print(f"\n{'='*72}\n 6. ABLATION (100 training steps, rotation target)\n{'='*72}")
B, n, k, d = 32, 128, 64, 256
anchors, queries = make_data(B, n, k, d)
rot, _ = torch.linalg.qr(torch.randn(d, d, device=dev))  # random orthogonal d x d
R = rot.unsqueeze(0)                                     # (1, d, d) for batched bmm
target = torch.bmm(queries, R.expand(B, -1, -1))
def eval_quality(model, anchors, queries, target, steps=100, lr=1e-3):
    """Train `model` with Adam on MSE to `target` and return the final MSE.

    Handles both kinds of model: a FlowEnsemble's forward returns a tensor,
    while a single flow returns a (pred, conf) tuple -- we take pred.
    The ensemble-vs-flow dispatch was duplicated in the original; it is
    factored into one inner helper here.
    """
    def _predict():
        # one forward pass, normalized to a plain prediction tensor
        out = model(anchors, queries)
        return out if isinstance(model, FlowEnsemble) else out[0]

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = (_predict() - target).pow(2).mean()
        loss.backward()
        opt.step()
    with torch.no_grad():
        return (_predict() - target).pow(2).mean().item()
# Solo flows: each flow trained alone on the rotation target.
print(f"\n {'Configuration':<35} {'MSE':>10} {'Params':>10}")
print(f" {'─'*35} {'─'*10} {'─'*10}")
for name, ctor in flow_ctors:
    try:
        flow = ctor(d, k).to(dev)
        params = sum(p.numel() for p in flow.parameters())
        mse = eval_quality(flow, anchors, queries, target)
        print(f" {name:<35} {mse:>10.4f} {params:>10,}")
        del flow
    except Exception as e:
        print(f" {name:<35} FAILED: {str(e)[:30]}")
# Pair ablations: indices into flow_ctors, which only holds smoke-test
# survivors.  NOTE(review): if any flow failed the smoke test the indices
# shift and these labels no longer match the flows actually paired --
# the `i < len(flow_ctors)` guard only prevents an IndexError.
pairs = [
    ('Quat + Orbital', [0, 4]),
    ('Velocity + Magnitude', [2, 3]),
    ('Orbital + Alignment', [4, 5]),
    ('Velocity + Orbital', [2, 4]),
]
for pair_name, indices in pairs:
    try:
        pair_flows = [flow_ctors[i][1](d, k).to(dev) for i in indices if i < len(flow_ctors)]
        if len(pair_flows) >= 2:
            ens = FlowEnsemble(pair_flows, d, fusion='weighted').to(dev)
            params = sum(p.numel() for p in ens.parameters())
            mse = eval_quality(ens, anchors, queries, target)
            print(f" {pair_name:<35} {mse:>10.4f} {params:>10,}")
            del ens, pair_flows
    except Exception as e:
        print(f" {pair_name:<35} FAILED: {str(e)[:30]}")
# Full ensemble under the two main fusion modes.
for fusion in ['weighted', 'residual']:
    try:
        all_flows = [ctor(d, k).to(dev) for _, ctor in flow_ctors]
        ens = FlowEnsemble(all_flows, d, fusion=fusion).to(dev)
        params = sum(p.numel() for p in ens.parameters())
        mse = eval_quality(ens, anchors, queries, target)
        print(f" {'Full (' + fusion + ')':<35} {mse:>10.4f} {params:>10,}")
        del ens, all_flows
    except Exception as e:
        print(f" {'Full (' + fusion + ')':<35} FAILED: {str(e)[:30]}")
# ═══════════════════════════════════════════════════════════════════
# 7. COMPILE COMPATIBILITY
# ═══════════════════════════════════════════════════════════════════
# torch.compile(fullgraph=True) on each flow at a small size: any graph
# break or unsupported op raises, and the first 12 chars of the error land
# in the 'fullgraph' column; raw vs compiled latency printed side by side.
print(f"\n{'='*72}\n 7. COMPILE COMPATIBILITY\n{'='*72}")
B, n, k, d = 8, 32, 16, 64
anchors, queries = make_data(B, n, k, d)
print(f"\n {'Flow':<22} {'fullgraph':>12} {'Raw':>10} {'Compiled':>12}")
print(f" {'─'*22} {'─'*12} {'─'*10} {'─'*12}")
for name, ctor in flow_ctors:
    try:
        flow = ctor(d, k).to(dev)
        t_raw = time_fn(lambda: flow(anchors, queries), warmup=3, runs=30)
        try:
            compiled = torch.compile(flow, fullgraph=True)
            # first call triggers compilation; keep it out of the timing loop
            compiled(anchors, queries); sync()
            t_comp = time_fn(lambda: compiled(anchors, queries), warmup=3, runs=30)
            status = "OK"
        except Exception as e:
            t_comp = -1  # sentinel: compile (or first compiled call) failed
            status = str(e)[:12]
        t_str = fmt(t_comp) if t_comp > 0 else "N/A"
        print(f" {name:<22} {status:>12} {fmt(t_raw):>10} {t_str:>12}")
        del flow
    except Exception as e:
        print(f" {name:<22} FAILED: {str(e)[:40]}")
# ═══════════════════════════════════════════════════════════════════
# 8. MEMORY
# ═══════════════════════════════════════════════════════════════════
# CUDA-only: peak allocator bytes attributable to a single forward pass.
# Fixes vs original: gc.collect() now runs BEFORE empty_cache() so Python-
# collectable tensors are returned to the caching allocator before the
# cache is emptied, and the printed value width matches the 10-char
# 'Peak MB' header (was 9, misaligning the column).
if dev.type == 'cuda':
    print(f"\n{'='*72}\n 8. MEMORY (B=32, n=128, k=64, d=256)\n{'='*72}")
    B, n, k, d = 32, 128, 64, 256
    anchors, queries = make_data(B, n, k, d)
    print(f"\n {'Flow':<22} {'Peak MB':>10}")
    print(f" {'─'*22} {'─'*10}")
    for name, ctor in flow_ctors:
        try:
            flow = ctor(d, k).to(dev)
            gc.collect(); torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            base = torch.cuda.memory_allocated()
            pred, conf = flow(anchors, queries); sync()
            peak = (torch.cuda.max_memory_allocated() - base) / 1024**2
            print(f" {name:<22} {peak:>10.1f}")
            del flow, pred, conf
        except Exception as e:
            print(f" {name:<22} FAILED: {str(e)[:30]}")
    # Full ensemble peak (all surviving flows, 'weighted' fusion).
    try:
        all_flows = [ctor(d, k).to(dev) for _, ctor in flow_ctors]
        ens = FlowEnsemble(all_flows, d, fusion='weighted').to(dev)
        gc.collect(); torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        base = torch.cuda.memory_allocated()
        out = ens(anchors, queries); sync()
        peak = (torch.cuda.max_memory_allocated() - base) / 1024**2
        print(f" {'Full ensemble':<22} {peak:>10.1f}")
        del ens, all_flows
    except Exception as e:
        print(f" {'Full ensemble':<22} FAILED: {str(e)[:30]}")
print(f"\n{'='*72}")
print(" Done.")
print(f"{'='*72}")