| """ |
| Flow Ensemble — Expanded Test Suite. |
| |
| Assumes geolip-core is installed (Colab with repo loaded). |
| Tests: smoke, linalg integration, multi-scale, ensemble fusion, |
| gradient health, ablation, compile compatibility, memory. |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import sys, time, gc |
|
|
| |
# Bind `LA` to geolip_core's linalg module when it can be imported, and to
# torch.linalg otherwise, so later `LA.eigh(...)` calls resolve either way.
# HAS_GEOLIP_LINALG records which of the two is in use.
try:
    import geolip_core.linalg as LA
    HAS_GEOLIP_LINALG = True
    print("geolip_core.linalg: available")
    # Kept inside the try on purpose: an ImportError raised here also
    # triggers the torch.linalg fallback below.
    LA.backend.status()
except ImportError:
    import torch.linalg as LA
    HAS_GEOLIP_LINALG = False
    print("geolip_core.linalg: NOT available, using torch.linalg fallback")
|
|
|
|
# Device used by every tensor and module in this script: CUDA when available,
# CPU otherwise.
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
def sync():
    """Call ``torch.cuda.synchronize()`` when running on CUDA; no-op on CPU."""
    if dev.type != 'cuda':
        return
    torch.cuda.synchronize()
|
|
def time_fn(fn, warmup=5, runs=50):
    """Return the mean wall-clock duration of one ``fn()`` call in milliseconds.

    Args:
        fn: Zero-argument callable to time.
        warmup: Number of untimed calls made before measuring.
        runs: Number of timed calls; the total elapsed time is divided by it.

    Returns:
        Average time per call, in milliseconds.
    """
    for _ in range(warmup):
        fn()
    # Synchronize before starting and before stopping the clock so the
    # measured interval is bracketed by sync() on both sides.
    sync()
    started = time.perf_counter()
    for _ in range(runs):
        fn()
    sync()
    elapsed_s = time.perf_counter() - started
    return elapsed_s / runs * 1000
|
|
def fmt(ms):
    """Format a duration given in milliseconds for table output.

    Values below 1 ms are shown as whole microseconds (``"250us"``);
    everything else is shown in milliseconds with two decimals (``"2.50ms"``).
    """
    below_one_ms = ms < 1
    return f"{ms*1000:.0f}us" if below_one_ms else f"{ms:.2f}ms"
|
|
def make_data(B, n, k, d):
    """Create random, L2-normalized anchor and query tensors on ``dev``.

    Args:
        B: Batch size.
        n: Number of queries per batch element.
        k: Number of anchors per batch element.
        d: Feature dimension (the axis that is normalized).

    Returns:
        ``(anchors, queries)`` with shapes ``(B, k, d)`` and ``(B, n, d)``.
    """
    def _unit_rows(rows):
        # Gaussian samples normalized along the last axis.
        return F.normalize(torch.randn(B, rows, d, device=dev), dim=-1)

    # Anchors are drawn first, then queries, to keep the RNG draw order.
    anchors = _unit_rows(k)
    queries = _unit_rows(n)
    return anchors, queries
|
|
|
|
| |
# Suite banner: title, selected device, and which linalg backend is in use.
print("=" * 72)
# The title separator was mis-decoded UTF-8 in the original; restored to an
# em dash.
print(" Flow Ensemble — Expanded Test Suite")
print("=" * 72)
print(f" device={dev} geolip_core.linalg={HAS_GEOLIP_LINALG}")
if dev.type == 'cuda':
    print(f" GPU: {torch.cuda.get_device_name()}")
print()
|
|
|
|
| |
| |
| |
# ── 1. Smoke test ──────────────────────────────────────────────────────────
# Build each flow once, run a forward pass, and report parameter count,
# output shape, mean forward time, mean confidence, and mean residual norm.
print(f"{'='*72}\n 1. SMOKE TEST\n{'='*72}")

B, n, k, d = 16, 64, 32, 128
anchors, queries = make_data(B, n, k, d)

# (display name, constructor taking (d, k)). The flow classes are referenced
# inside lambdas so that a missing class only fails its own row (the call
# happens inside the try below) instead of aborting the script here.
flows_cfg = [
    ('QuaternionFlow', lambda d, k: QuaternionFlow(d, k, n_heads=4)),
    ('QuaternionLiteFlow', lambda d, k: QuaternionLiteFlow(d, k)),
    ('VelocityFlow', lambda d, k: VelocityFlow(d, k)),
    ('MagnitudeFlow', lambda d, k: MagnitudeFlow(d, k)),
    ('OrbitalFlow', lambda d, k: OrbitalFlow(d, k)),
    ('AlignmentFlow', lambda d, k: AlignmentFlow(d, k)),
]

print(f"\n {'Flow':<22} {'Params':>8} {'Shape':>14} {'Time':>10} {'Conf':>8} {'Res norm':>10}")
# Rule characters were mis-decoded UTF-8 in the original; restored to the
# box-drawing horizontal line.
print(f" {'─'*22} {'─'*8} {'─'*14} {'─'*10} {'─'*8} {'─'*10}")

# Flows (and their constructors) that survive the smoke test. Later sections
# build ensembles only from these.
live_flows = []
flow_ctors = []
for name, ctor in flows_cfg:
    try:
        flow = ctor(d, k).to(dev)
        params = sum(p.numel() for p in flow.parameters())
        pred, conf = flow(anchors, queries)
        # The lambda is consumed immediately, so binding `flow` late is safe.
        ms = time_fn(lambda: flow(anchors, queries))
        res = (pred - queries).norm(dim=-1).mean().item()
        shape_str = str(tuple(pred.shape))
        print(f" {name:<22} {params:>8,} {shape_str:>14} {fmt(ms):>10} {conf.mean().item():>8.3f} {res:>10.3f}")
        live_flows.append(flow)
        flow_ctors.append((name, ctor))
    except Exception as e:
        # Best-effort: report the failure on its row and keep going.
        print(f" {name:<22} FAILED: {str(e)[:50]}")
|
|
|
|
| |
| |
| |
# ── 2. Linalg integration ──────────────────────────────────────────────────
# Only meaningful when geolip_core.linalg was imported; otherwise skipped.
print(f"\n{'='*72}\n 2. LINALG INTEGRATION\n{'='*72}")

if HAS_GEOLIP_LINALG:
    print("\n Testing eigh dispatch in MagnitudeFlow and OrbitalFlow...")
    for FlowCls in [MagnitudeFlow, OrbitalFlow]:
        flow = FlowCls(d, k).to(dev)
        pred, conf = flow(anchors, queries)
        ok = torch.isfinite(pred).all().item() and torch.isfinite(conf).all().item()
        print(f" {flow.name:<18} finite={ok} conf={conf.mean():.3f}")

    # Eigendecompose the Gram matrix of OrbitalFlow's projected anchors with
    # LA.eigh, then check the eigenvectors are numerically orthonormal.
    oflow = OrbitalFlow(d, k).to(dev)
    a_geom = oflow.anchor_proj(anchors)
    G = torch.bmm(a_geom.transpose(-2, -1), a_geom)
    vals, vecs = LA.eigh(G)
    print(f"\n Gram eigenspectrum: shape={tuple(vals.shape)} "
          f"range=[{vals.min().item():.4f}, {vals.max().item():.4f}]")
    # V^T V should equal the identity; report the largest absolute deviation.
    identity = torch.eye(oflow.geom_dim, device=dev)
    orth_err = (torch.bmm(vecs.mT, vecs) - identity).abs().max().item()
    print(f" Eigenvector orth err: {orth_err:.2e}")
else:
    # The dash was mis-decoded UTF-8 in the original; restored to an em dash.
    print(" Skipped — geolip_core.linalg not available")
|
|
|
|
| |
| |
| |
# ── 3. Multi-scale ─────────────────────────────────────────────────────────
# Run OrbitalFlow over several problem sizes and report timing and whether
# the prediction stayed finite.
print(f"\n{'='*72}\n 3. MULTI-SCALE\n{'='*72}")

# (B, n, k, d, label)
configs = [
    (4, 16, 8, 64, 'tiny'),
    (16, 64, 32, 128, 'small'),
    (32, 128, 64, 256, 'medium'),
    (64, 256, 128, 256, 'large'),
    (8, 512, 256, 512, 'wide'),
]

print("\n OrbitalFlow across scales:")
print(f" {'Config':<10} {'B':>4} {'n':>5} {'k':>5} {'d':>5} {'Time':>10} {'OK':>4}")
# Rule characters were mis-decoded UTF-8 in the original; restored to the
# box-drawing horizontal line.
print(f" {'─'*10} {'─'*4} {'─'*5} {'─'*5} {'─'*5} {'─'*10} {'─'*4}")

for B_, n_, k_, d_, label in configs:
    try:
        of = OrbitalFlow(d_, k_).to(dev)
        a, q = make_data(B_, n_, k_, d_)
        pred, conf = of(a, q)
        ms = time_fn(lambda: of(a, q), warmup=3, runs=20)
        ok = torch.isfinite(pred).all().item()
        print(f" {label:<10} {B_:>4} {n_:>5} {k_:>5} {d_:>5} {fmt(ms):>10} {'OK' if ok else 'NO':>4}")
        # Drop the per-config model and inputs before the next, larger config.
        del of, a, q
    except Exception as e:
        # Best-effort: report the failing config and continue with the rest.
        print(f" {label:<10} {B_:>4} {n_:>5} {k_:>5} {d_:>5} FAILED: {str(e)[:30]}")
|
|
|
|
| |
| |
| |
# ── 4. Ensemble fusion ─────────────────────────────────────────────────────
# For each fusion mode, build an ensemble over the flows that passed the
# smoke test and report timing, output norm, inter-flow diversity, and
# per-flow diagnostics.
print(f"\n{'='*72}\n 4. ENSEMBLE FUSION\n{'='*72}")

B, n, k, d = 16, 64, 32, 128
anchors, queries = make_data(B, n, k, d)

for fusion in ['weighted', 'gated', 'residual']:
    ens = FlowEnsemble(live_flows, d, fusion=fusion).to(dev)
    out = ens(anchors, queries)
    ms = time_fn(lambda: ens(anchors, queries), warmup=3, runs=20)

    # Diversity = 1 - mean pairwise cosine similarity between the flattened
    # predictions of every unordered pair of member flows.
    preds = [flow(anchors, queries)[0] for flow in ens.flows]
    cos_sims = [
        F.cosine_similarity(left.flatten(1), right.flatten(1), dim=-1).mean().item()
        for i, left in enumerate(preds)
        for right in preds[i + 1:]
    ]
    # max(..., 1) avoids a division by zero when there are fewer than two flows.
    avg_sim = sum(cos_sims) / max(len(cos_sims), 1)

    print(f"\n {fusion}: time={fmt(ms)} norm={out.norm(dim=-1).mean():.3f} diversity={1-avg_sim:.3f}")
    diag = ens.flow_diagnostics(anchors, queries)
    for fname, stats in diag.items():
        # The plus-minus sign was mis-decoded UTF-8 in the original; restored.
        print(f" {fname:<18} conf={stats['confidence_mean']:.3f}±{stats['confidence_std']:.3f} "
              f"res={stats['residual_norm']:.3f}")
    del ens
|
|
|
|
| |
| |
| |
# ── 5. Gradient health ─────────────────────────────────────────────────────
# For each loss, backprop through a freshly built residual-fusion ensemble
# and report the parameter-gradient norm of every member flow.
print(f"\n{'='*72}\n 5. GRADIENT HEALTH\n{'='*72}")

B, n, k, d = 16, 64, 32, 128
anchors, queries = make_data(B, n, k, d)

# loss name -> callable(output, queries) returning a scalar loss
losses = {
    'mse': (lambda o, q: (o - q).pow(2).mean()),
    'cosine': (lambda o, q: (1 - F.cosine_similarity(o, q, dim=-1)).mean()),
    'norm': (lambda o, q: o.norm(dim=-1).mean()),
}

print(f"\n {'Flow':<18} {'Loss':<10} {'Grad norm':>12} {'Status':>8}")
# Rule characters were mis-decoded UTF-8 in the original; restored to the
# box-drawing horizontal line.
print(f" {'─'*18} {'─'*10} {'─'*12} {'─'*8}")

for loss_name, loss_fn in losses.items():
    try:
        # Fresh flows per loss so gradients never accumulate across losses.
        test_flows_grad = [ctor(d, k).to(dev) for _, ctor in flow_ctors]
        ens_g = FlowEnsemble(test_flows_grad, d, fusion='residual').to(dev)
        ens_g.zero_grad()
        anchors_g = anchors.detach().clone().requires_grad_(True)
        queries_g = queries.detach().clone().requires_grad_(True)
        out = ens_g(anchors_g, queries_g)
        # The loss target is detached, so it contributes no gradient itself.
        loss = loss_fn(out, queries_g.detach())
        loss.backward()

        for flow in ens_g.flows:
            grads = [p.grad for p in flow.parameters() if p.grad is not None]
            if not grads:
                print(f" {flow.name:<18} {loss_name:<10} {'no grads':>12} {'WARN':>8}")
                continue
            # Norm over all of this flow's parameter gradients, concatenated.
            gn = torch.cat([g.flatten() for g in grads]).norm().item()
            # Anything outside (1e-8, 1e4), including NaN, is flagged.
            status = "OK" if 1e-8 < gn < 1e4 else "WARN"
            print(f" {flow.name:<18} {loss_name:<10} {gn:>12.2e} {status:>8}")
        del ens_g, test_flows_grad
    except RuntimeError as e:
        msg = str(e)
        # Autograd reports in-place modification with one of these phrasings.
        if 'inplace' in msg.lower() or 'in-place' in msg.lower() or 'modified by' in msg:
            print(f" {'*':>18} {loss_name:<10} {'IN-PLACE ERR':>12} {'NOTE':>8}")
            # The dash was mis-decoded UTF-8 in the original; restored.
            print(" FL eigh deflation uses indexed assignment — needs .clone() fix")
        else:
            print(f" {'*':>18} {loss_name:<10} {'ERROR':>12}")
            print(f" {msg[:60]}")
|
|
|
|
| |
| |
| |
# ── 6. Ablation ────────────────────────────────────────────────────────────
print(f"\n{'='*72}\n 6. ABLATION (100 training steps, rotation target)\n{'='*72}")

B, n, k, d = 32, 128, 64, 256
anchors, queries = make_data(B, n, k, d)
# Shared regression target: every query multiplied by one fixed matrix, the
# Q factor from the QR decomposition of a random d x d matrix. R keeps a
# leading batch axis of size 1 and is expanded to the batch for bmm.
R = torch.linalg.qr(torch.randn(d, d, device=dev))[0][None]
target = torch.bmm(queries, R.expand(B, d, d))
|
|
def eval_quality(model, anchors, queries, target, steps=100, lr=1e-3):
    """Train ``model`` with Adam on an MSE objective and return the final MSE.

    Args:
        model: A ``FlowEnsemble`` (its call result is used directly) or any
            other module whose call result is indexed with ``[0]`` to get
            the prediction.
        anchors: Anchor tensor passed as the model's first argument.
        queries: Query tensor passed as the model's second argument.
        target: Tensor the prediction is regressed onto.
        steps: Number of optimizer steps.
        lr: Adam learning rate.

    Returns:
        The mean squared error after training, evaluated under
        ``torch.no_grad()``, as a Python float.
    """
    def _predict():
        # Ensembles return the prediction itself; other models return a
        # sequence whose first element is the prediction.
        output = model(anchors, queries)
        return output if isinstance(model, FlowEnsemble) else output[0]

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        step_loss = (_predict() - target).pow(2).mean()
        step_loss.backward()
        optimizer.step()

    with torch.no_grad():
        final_mse = (_predict() - target).pow(2).mean()
        return final_mse.item()
|
|
# Ablation table, part 1: each surviving flow trained on its own.
print(f"\n {'Configuration':<35} {'MSE':>10} {'Params':>10}")
# Rule characters were mis-decoded UTF-8 in the original; restored to the
# box-drawing horizontal line.
print(f" {'─'*35} {'─'*10} {'─'*10}")

for name, ctor in flow_ctors:
    try:
        flow = ctor(d, k).to(dev)
        params = sum(p.numel() for p in flow.parameters())
        mse = eval_quality(flow, anchors, queries, target)
        print(f" {name:<35} {mse:>10.4f} {params:>10,}")
        del flow
    except Exception as e:
        # Best-effort: report the failing flow and continue with the rest.
        print(f" {name:<35} FAILED: {str(e)[:30]}")
|
|
# Ablation table, part 2: two-flow ensembles with weighted fusion.
#
# Members are looked up by name. `flow_ctors` holds only the flows that
# passed the smoke test, so positional indices into it shift whenever an
# earlier flow failed, which would mislabel or silently drop a pair.
ctor_by_name = dict(flow_ctors)
pairs = [
    ('Quat + Orbital', ('QuaternionFlow', 'OrbitalFlow')),
    ('Velocity + Magnitude', ('VelocityFlow', 'MagnitudeFlow')),
    ('Orbital + Alignment', ('OrbitalFlow', 'AlignmentFlow')),
    ('Velocity + Orbital', ('VelocityFlow', 'OrbitalFlow')),
]
for pair_name, members in pairs:
    missing = [m for m in members if m not in ctor_by_name]
    if missing:
        # A member failed the smoke test; say so instead of skipping silently.
        print(f" {pair_name:<35} SKIPPED: missing {', '.join(missing)}")
        continue
    try:
        pair_flows = [ctor_by_name[m](d, k).to(dev) for m in members]
        ens = FlowEnsemble(pair_flows, d, fusion='weighted').to(dev)
        params = sum(p.numel() for p in ens.parameters())
        mse = eval_quality(ens, anchors, queries, target)
        print(f" {pair_name:<35} {mse:>10.4f} {params:>10,}")
        del ens, pair_flows
    except Exception as e:
        # Best-effort: report the failing pair and continue with the rest.
        print(f" {pair_name:<35} FAILED: {str(e)[:30]}")
|
|
# Ablation table, part 3: every surviving flow in one ensemble, once per
# fusion mode.
for fusion in ['weighted', 'residual']:
    row_label = 'Full (' + fusion + ')'
    try:
        all_flows = [make_flow(d, k).to(dev) for _, make_flow in flow_ctors]
        ens = FlowEnsemble(all_flows, d, fusion=fusion).to(dev)
        params = sum(p.numel() for p in ens.parameters())
        mse = eval_quality(ens, anchors, queries, target)
        print(f" {row_label:<35} {mse:>10.4f} {params:>10,}")
        del ens, all_flows
    except Exception as e:
        print(f" {row_label:<35} FAILED: {str(e)[:30]}")
|
|
|
|
| |
| |
| |
# ── 7. Compile compatibility ───────────────────────────────────────────────
# Time each flow eagerly, then try torch.compile(fullgraph=True) and time the
# compiled module. A compile failure is reported in the status column.
print(f"\n{'='*72}\n 7. COMPILE COMPATIBILITY\n{'='*72}")

B, n, k, d = 8, 32, 16, 64
anchors, queries = make_data(B, n, k, d)

print(f"\n {'Flow':<22} {'fullgraph':>12} {'Raw':>10} {'Compiled':>12}")
# Rule characters were mis-decoded UTF-8 in the original; restored to the
# box-drawing horizontal line.
print(f" {'─'*22} {'─'*12} {'─'*10} {'─'*12}")

for name, ctor in flow_ctors:
    try:
        flow = ctor(d, k).to(dev)
        t_raw = time_fn(lambda: flow(anchors, queries), warmup=3, runs=30)
        try:
            compiled = torch.compile(flow, fullgraph=True)
            # First call is made (and synchronized) outside the timed region.
            compiled(anchors, queries)
            sync()
            t_comp = time_fn(lambda: compiled(anchors, queries), warmup=3, runs=30)
            status = "OK"
        except Exception as e:
            # No compiled timing available; show a short error prefix instead.
            t_comp = None
            status = str(e)[:12]
        t_str = fmt(t_comp) if t_comp is not None else "N/A"
        print(f" {name:<22} {status:>12} {fmt(t_raw):>10} {t_str:>12}")
        del flow
    except Exception as e:
        # Construction or eager timing failed; report and move on.
        print(f" {name:<22} FAILED: {str(e)[:40]}")
|
|
|
|
| |
| |
| |
# ── 8. Memory (CUDA only) ──────────────────────────────────────────────────
if dev.type == 'cuda':
    print(f"\n{'='*72}\n 8. MEMORY (B=32, n=128, k=64, d=256)\n{'='*72}")

    B, n, k, d = 32, 128, 64, 256
    anchors, queries = make_data(B, n, k, d)

    def _peak_forward_mb(module):
        """Return peak extra CUDA memory (MiB) of one ``module(anchors, queries)`` call.

        The baseline is the memory already allocated just before the call
        (inputs and the module itself), so only the forward pass is counted.
        """
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.reset_peak_memory_stats()
        base = torch.cuda.memory_allocated()
        module(anchors, queries)
        sync()
        return (torch.cuda.max_memory_allocated() - base) / 1024**2

    print(f"\n {'Flow':<22} {'Peak MB':>10}")
    # Rule characters were mis-decoded UTF-8 in the original; restored to the
    # box-drawing horizontal line.
    print(f" {'─'*22} {'─'*10}")

    for name, ctor in flow_ctors:
        try:
            flow = ctor(d, k).to(dev)
            peak = _peak_forward_mb(flow)
            print(f" {name:<22} {peak:>9.1f}")
            del flow
        except Exception as e:
            # Best-effort: report the failing flow and continue with the rest.
            print(f" {name:<22} FAILED: {str(e)[:30]}")

    try:
        all_flows = [ctor(d, k).to(dev) for _, ctor in flow_ctors]
        ens = FlowEnsemble(all_flows, d, fusion='weighted').to(dev)
        peak = _peak_forward_mb(ens)
        print(f" {'Full ensemble':<22} {peak:>9.1f}")
        del ens, all_flows
    except Exception as e:
        print(f" {'Full ensemble':<22} FAILED: {str(e)[:30]}")
|
|
# Closing banner.
print("\n" + "=" * 72)
print(" Done.")
print("=" * 72)