AbstractPhil commited on
Commit
fc29ddb
·
verified ·
1 Parent(s): d8a6cd0

Create cv_sweep_mha_testing.py

Browse files
Files changed (1) hide show
  1. cv_sweep_mha_testing.py +338 -0
cv_sweep_mha_testing.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MHA CV Relational Test — Prototype
3
+ Train a minimal embedding + MHA + classifier on 10 noise patterns.
4
+ Measure CV on embedding weights, Q/K/V projections, and attention output
5
+ across different head counts per embedding dimension.
6
+
7
+ Hypothesis: head_dim (D / n_heads) determines CV of internal representations,
8
+ and the band-valid head_dims produce qualitatively different geometric behavior.
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import math
15
+
16
+
17
+ # ── CM primitives ──
18
+
19
def cayley_menger_vol2(points):
    """Squared simplex volume from the Cayley-Menger determinant.

    points: (B, N, D) -- B simplices, each given by N vertices in D dims.
    Returns a (B,) tensor holding the squared volume of each (N-1)-simplex,
    in the dtype of ``points``. Values can come out slightly negative for
    degenerate simplices because of round-off in the determinant.
    """
    batch, n_pts, _ = points.shape
    simplex_dim = n_pts - 1

    # Pairwise squared distances via the Gram matrix; relu clamps the small
    # negative values that round-off can produce.
    inner = torch.bmm(points, points.transpose(1, 2))
    sq_norm = torch.diagonal(inner, dim1=1, dim2=2)
    sq_dist = F.relu(sq_norm.unsqueeze(2) + sq_norm.unsqueeze(1) - 2 * inner)

    # Distance matrix bordered by a row/column of ones (zero in the corner).
    bordered = torch.zeros(
        batch, n_pts + 1, n_pts + 1, device=points.device, dtype=points.dtype
    )
    bordered[:, 0, 1:] = 1.0
    bordered[:, 1:, 0] = 1.0
    bordered[:, 1:, 1:] = sq_dist

    # The determinant is always evaluated in float32, then cast back.
    det = torch.linalg.det(bordered.float()).to(points.dtype)

    # vol^2 = (-1)^(k+1) * det / (2^k * (k!)^2) for a k-simplex.
    sign = (-1.0) ** (simplex_dim + 1)
    denom = (2 ** simplex_dim) * (math.factorial(simplex_dim) ** 2)
    return sign * det / denom
32
+
33
+
34
def cv_metric(weight, n_samples=300):
    """Coefficient of variation (std / mean) of random pentachoron volumes.

    weight: (N, D) matrix whose rows are treated as points.
    Draws ``n_samples`` random 5-row subsets, computes each simplex volume,
    and returns std/mean of those volumes as a Python float. Returns None
    when there are fewer than 5 rows or fewer than 10 non-degenerate
    simplices were drawn.
    """
    n_rows, _ = weight.shape
    if n_rows < 5:
        return None

    # Only the first `pool` rows (at most 512) are eligible as vertices.
    pool = min(n_rows, 512)
    draws = [
        torch.randperm(pool, device=weight.device)[:5]
        for _ in range(n_samples)
    ]
    simplex_idx = torch.stack(draws)
    vertices = weight[:pool][simplex_idx]
    sq_vol = cayley_menger_vol2(vertices)

    # Drop degenerate (or numerically negative) simplices before the sqrt.
    keep = sq_vol > 1e-20
    if keep.sum() < 10:
        return None

    volumes = sq_vol[keep].sqrt()
    ratio = volumes.std() / (volumes.mean() + 1e-8)
    return ratio.item()
51
+
52
+
53
+ # ── Minimal model ──
54
+
55
class MHAClassifier(nn.Module):
    """Embedding + learned positions + one self-attention layer + linear head."""

    def __init__(self, vocab, dim, n_heads, seq_len, n_classes):
        super().__init__()
        # Submodules are created in this fixed order; their random
        # initialisation consumes the global RNG in the same sequence.
        self.emb = nn.Embedding(vocab, dim)
        self.pos = nn.Parameter(torch.randn(1, seq_len, dim) * 0.02)
        self.mha = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x):
        """Classify token sequences. x: (B, seq_len) token indices -> (B, n_classes) logits."""
        tokens = self.emb(x) + self.pos
        attended, _ = self.mha(tokens, tokens, tokens)
        # Residual + norm, then mean-pool over the sequence axis.
        pooled = self.norm(tokens + attended).mean(dim=1)
        return self.head(pooled)

    @torch.no_grad()
    def forward_activations(self, x, n_heads):
        """Collect per-head Q/K/V activations and the post-attention output.

        Returns a dict of (B*seq, head_dim) tensors per head, plus the
        full-width "act_emb" and "act_post_full" entries of shape (B*seq, D).
        """
        hidden = self.emb(x) + self.pos  # (B, seq, D)
        batch, seq, width = hidden.shape
        head_dim = width // n_heads

        # Apply the packed input projection by hand, then split it into
        # Q, K, V and expose the head axis: (B, seq, n_heads, head_dim).
        packed = F.linear(hidden, self.mha.in_proj_weight, self.mha.in_proj_bias)
        q, k, v = (
            part.view(batch, seq, n_heads, head_dim)
            for part in packed.chunk(3, dim=-1)
        )

        # Attention output followed by the same residual + norm as forward().
        attn_out, _ = self.mha(hidden, hidden, hidden)
        post_attn = self.norm(hidden + attn_out)  # (B, seq, D)
        post_heads = post_attn.view(batch, seq, n_heads, head_dim)

        acts = {}
        per_head_sources = (("Q", q), ("K", k), ("V", v), ("post", post_heads))
        for i in range(n_heads):
            for label, source in per_head_sources:
                acts[f"act_{label}_h{i}"] = source[:, :, i, :].reshape(-1, head_dim)

        # Full-width activations as well.
        acts["act_emb"] = hidden.reshape(-1, width)
        acts["act_post_full"] = post_attn.reshape(-1, width)
        return acts

    def get_qkv_weights(self):
        """Return the detached Q, K, V projection matrices, each (dim, dim)."""
        # in_proj_weight stacks Q, K, V row-wise: (3*dim, dim).
        packed = self.mha.in_proj_weight.detach()
        width = packed.shape[1]
        q_w = packed[:width]
        k_w = packed[width:2 * width]
        v_w = packed[2 * width:]
        return q_w, k_w, v_w

    def get_per_head_projections(self, n_heads):
        """Split Q/K/V weights by head; returns three lists of (head_dim, dim) slices."""
        q_w, k_w, v_w = self.get_qkv_weights()
        head_dim = q_w.shape[0] // n_heads

        def rows_per_head(matrix):
            return [
                matrix[i * head_dim:(i + 1) * head_dim]
                for i in range(n_heads)
            ]

        return rows_per_head(q_w), rows_per_head(k_w), rows_per_head(v_w)
132
+
133
+
134
+ # ── Data: 10 noise patterns with perturbations ──
135
+
136
def make_data(n_classes=10, samples_per_class=50, seq_len=8, vocab=256):
    """Build a shuffled toy classification set of noisy token patterns.

    Every class owns one fixed random token sequence; each sample is a copy
    of its class sequence with roughly a quarter of the positions redrawn.
    Returns (x, y): x is (n_classes * samples_per_class, seq_len) token ids,
    y the matching class labels.

    Note: this reseeds torch's global RNG with 42 on every call.
    """
    torch.manual_seed(42)
    # One fixed token sequence per class.
    prototypes = torch.randint(0, vocab, (n_classes, seq_len))

    samples, labels = [], []
    for label in range(n_classes):
        for _ in range(samples_per_class):
            noisy = prototypes[label].clone()
            # Each position is redrawn with probability 0.25.
            redraw = torch.rand(seq_len) < 0.25
            noisy[redraw] = torch.randint(0, vocab, (int(redraw.sum()),))
            samples.append(noisy)
            labels.append(label)

    x = torch.stack(samples)
    y = torch.tensor(labels)
    order = torch.randperm(len(x))
    return x[order], y[order]
156
+
157
+
158
+ # ── CV measurement suite ──
159
+
160
def measure_all_cv(model, n_heads, x=None):
    """Measure CV on all relevant weight matrices and, optionally, activations.

    model: an MHAClassifier-style module exposing ``emb``, ``mha``, ``head``,
        ``get_qkv_weights``, ``get_per_head_projections`` and
        ``forward_activations``.
    n_heads: number of attention heads used to split per-head quantities.
    x: optional batch of token indices; when given, activation CVs are
        measured too.

    Returns a dict mapping a label to the value returned by ``cv_metric``
    (a float, or None when it could not be computed). Entries are inserted
    in a fixed order: weights first, then activations.

    The model is put in eval mode only while activations are captured and
    its previous train/eval mode is restored afterwards, so calling this
    mid-training does not silently change the caller's model state.
    """
    results = {}

    # Embedding weights
    results["emb"] = cv_metric(model.emb.weight.detach())

    # Full Q, K, V projection matrices (dim × dim)
    q_w, k_w, v_w = model.get_qkv_weights()
    results["Q_full"] = cv_metric(q_w)
    results["K_full"] = cv_metric(k_w)
    results["V_full"] = cv_metric(v_w)

    # Per-head projections (head_dim × dim) — CV measured on head_dim rows
    q_heads, k_heads, v_heads = model.get_per_head_projections(n_heads)
    for i in range(n_heads):
        results[f"Q_h{i}"] = cv_metric(q_heads[i])
        results[f"K_h{i}"] = cv_metric(k_heads[i])
        results[f"V_h{i}"] = cv_metric(v_heads[i])

    # Output projection
    results["out_proj"] = cv_metric(model.mha.out_proj.weight.detach())

    # Classifier head
    results["cls_head"] = cv_metric(model.head.weight.detach())

    # Activations — the space where attention actually operates
    if x is not None:
        was_training = model.training
        model.eval()
        try:
            acts = model.forward_activations(x, n_heads)
        finally:
            # Hand the model back in whatever mode the caller had it in.
            model.train(was_training)
        for name, tensor in acts.items():
            results[name] = cv_metric(tensor)

    return results
197
+
198
+
199
def fmt_cv(cv):
    """Format a CV value for the report tables.

    None becomes the " N/A " placeholder. Otherwise the value is printed
    with four decimals and wrapped in "*" when it lies strictly inside the
    0.13-0.30 band, or in spaces when it does not.
    """
    if cv is None:
        return " N/A "
    marker = " "
    if 0.13 < cv < 0.30:
        marker = "*"
    return marker + format(cv, ".4f") + marker
204
+
205
+
206
+ # ── Training + measurement loop ──
207
+
208
def run_experiment(dim, n_heads, vocab=256, seq_len=8, n_classes=10, epochs=50, lr=1e-3):
    """Train one MHAClassifier configuration and report CV before, during and after.

    dim: embedding width; n_heads: attention head count (head_dim = dim // n_heads).
    vocab, seq_len, n_classes: forwarded to make_data and the model.
    epochs: number of full-batch Adam steps; lr: learning rate.

    Prints the CV table before training, at epoch ``epochs // 2`` and after
    training, and returns a dict with keys "dim", "n_heads", "head_dim",
    "pre", "mid", "post" and "acc". "mid" is None when the midpoint epoch is
    never reached (epochs < 2).
    """
    head_dim = dim // n_heads
    print(f"\n{'='*70}")
    print(f"D={dim} heads={n_heads} head_dim={head_dim}")
    print(f"{'='*70}")

    x, y = make_data(n_classes=n_classes, seq_len=seq_len, vocab=vocab)
    model = MHAClassifier(vocab, dim, n_heads, seq_len, n_classes)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Pre-training CV
    print("\n [pre-train]")
    pre_cv = measure_all_cv(model, n_heads, x)
    for name, cv in pre_cv.items():
        print(f" {name:16s}: {fmt_cv(cv)}")

    # Training (full batch, one optimizer step per epoch)
    mid_cv = None
    loss = None  # stays None if epochs == 0; handled after the loop
    for epoch in range(1, epochs + 1):
        model.train()
        opt.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        opt.step()

        if epoch == epochs // 2:
            model.eval()
            with torch.no_grad():
                acc = (model(x).argmax(-1) == y).float().mean().item()
            mid_cv = measure_all_cv(model, n_heads, x)
            print(f"\n [epoch {epoch}] loss={loss.item():.4f} acc={acc:.2%}")
            for name, cv in mid_cv.items():
                print(f" {name:16s}: {fmt_cv(cv)}")

    # Post-training CV
    model.eval()
    with torch.no_grad():
        final_logits = model(x)
        acc = (final_logits.argmax(-1) == y).float().mean().item()
        if loss is None:
            # No training step ran, so there is no last-step loss to report;
            # evaluate it once instead of failing on an unbound name.
            loss = F.cross_entropy(final_logits, y)
    # When training ran, the reported loss is the one from the last step.
    print(f"\n [post-train] loss={loss.item():.4f} acc={acc:.2%}")
    post_cv = measure_all_cv(model, n_heads, x)
    for name, cv in post_cv.items():
        before = pre_cv.get(name)
        delta = ""
        if cv is not None and before is not None:
            delta = f" Δ={cv - before:+.4f}"
        print(f" {name:16s}: {fmt_cv(cv)}{delta}")

    return {
        "dim": dim, "n_heads": n_heads, "head_dim": head_dim,
        "pre": pre_cv, "mid": mid_cv, "post": post_cv, "acc": acc,
    }
261
+
262
+
263
+ # ── Main ──
264
+
265
if __name__ == "__main__":
    # Sweep head counts for several embedding widths, then print summary
    # tables of post-training weight CV, activation CV and CV movement.
    print("MHA CV Relational Test — Prototype")
    print("Band: 0.13 < CV < 0.30")

    configs = [
        # D=64: head_dims 64, 32, 16, 8
        (64, 1),
        (64, 2),
        (64, 4),
        (64, 8),
        # D=128: head_dims 128, 64, 32, 16
        (128, 1),
        (128, 2),
        (128, 4),
        (128, 8),
        # D=256: head_dims 256, 128, 64, 32
        (256, 1),
        (256, 2),
        (256, 4),
        (256, 8),
    ]

    all_results = []
    for dim, n_heads in configs:
        r = run_experiment(dim, n_heads)
        all_results.append(r)

    def fmt_delta(pre, post, key):
        """Signed post - pre difference for one entry, or a placeholder if either is missing."""
        a, b = pre.get(key), post.get(key)
        if a is not None and b is not None:
            return f"{b - a:+.4f}"
        return " N/A "

    # Summary — Weights
    print(f"\n\n{'='*70}")
    print("SUMMARY: Post-training WEIGHT CV by head_dim")
    print(f"{'='*70}")
    print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'emb':>8} {'Q_full':>8} {'K_full':>8} {'V_full':>8} {'out':>8} | acc")
    print("-" * 80)
    for r in all_results:
        p = r["post"]
        print(f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | "
              f"{fmt_cv(p.get('emb')):>8} {fmt_cv(p.get('Q_full')):>8} "
              f"{fmt_cv(p.get('K_full')):>8} {fmt_cv(p.get('V_full')):>8} "
              f"{fmt_cv(p.get('out_proj')):>8} | {r['acc']:.2%}")

    # Summary — Activations (the real test)
    print(f"\n\n{'='*70}")
    print("SUMMARY: Post-training ACTIVATION CV by head_dim")
    print("(These measure the space where attention actually operates)")
    print(f"{'='*70}")
    print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'act_emb':>8} {'aQ_h0':>8} {'aK_h0':>8} {'aV_h0':>8} {'aPost0':>8} {'act_full':>8} | acc")
    print("-" * 90)
    for r in all_results:
        p = r["post"]
        print(f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | "
              f"{fmt_cv(p.get('act_emb')):>8} "
              f"{fmt_cv(p.get('act_Q_h0')):>8} {fmt_cv(p.get('act_K_h0')):>8} "
              f"{fmt_cv(p.get('act_V_h0')):>8} {fmt_cv(p.get('act_post_h0')):>8} "
              f"{fmt_cv(p.get('act_post_full')):>8} | {r['acc']:.2%}")

    # Summary — Activation CV delta (pre→post)
    print(f"\n\n{'='*70}")
    print("SUMMARY: ACTIVATION CV movement (post - pre)")
    print(f"{'='*70}")
    print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'act_emb':>8} {'aQ_h0':>8} {'aK_h0':>8} {'aV_h0':>8} {'aPost0':>8} {'act_full':>8}")
    print("-" * 80)
    for r in all_results:
        pre, post = r["pre"], r["post"]
        print(f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | "
              f"{fmt_delta(pre, post, 'act_emb'):>8} "
              f"{fmt_delta(pre, post, 'act_Q_h0'):>8} {fmt_delta(pre, post, 'act_K_h0'):>8} "
              f"{fmt_delta(pre, post, 'act_V_h0'):>8} {fmt_delta(pre, post, 'act_post_h0'):>8} "
              f"{fmt_delta(pre, post, 'act_post_full'):>8}")