theapemachine committed on
Commit
de14582
Β·
verified Β·
1 Parent(s): 7853236

Upload ablations.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ablations.py +824 -1
ablations.py CHANGED
@@ -1 +1,824 @@
1
- See file content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Sparse Transformer: Definitive Ablation Suite
4
+
5
+ Builds on v18_fast_knn_triton.py. Addresses all three structural gaps
6
+ identified in the critique:
7
+
8
+ 1. PHANTOM MOMENTUM ABLATION
9
+ - "phantom": standard Adam β€” inactive chunks' moments decay on zero grad (default)
10
+ - "frozen": inactive chunks' Adam state (m, v) is completely frozen
11
+ Compare across all schedulers to isolate whether convergence is driven
12
+ by the chunking algorithm or by phantom momentum acting as regularization.
13
+
14
+ 2. COMPUTE-MATCHED BASELINES
15
+ - Dense at same steps (standard comparison)
16
+ - Dense at fewer steps matching sparse FLOPs
17
+ - Natively smaller dense model matching sparse active capacity
18
+
19
+ 3. UNIFIED HARDWARE
20
+ Everything on CUDA (A10G). Single hardware stack.
21
+
22
+ Plus: KNN vs EMA vs Random vs Oracle predictor comparison with proper
23
+ oracle overlap measurement.
24
+
25
+ Run:
26
+ python ablations.py --device cuda --steps 1000 --n_embd 1024 --experiment all
27
+ python ablations.py --device cuda --experiment phantom_momentum
28
+ python ablations.py --device cuda --experiment compute_matched
29
+ python ablations.py --device cuda --experiment predictor_accuracy
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import argparse
35
+ import json
36
+ import math
37
+ import os
38
+ import random
39
+ import sys
40
+ import time
41
+ from collections import defaultdict
42
+ from typing import Dict, List, Literal, Optional, Tuple
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+ import torch.nn.functional as F
47
+
48
+ try:
49
+ import triton
50
+ import triton.language as tl
51
+ HAS_TRITON = True
52
+ except ImportError:
53
+ HAS_TRITON = False
54
+
55
+ try:
56
+ import tiktoken
57
+ HAS_TIKTOKEN = True
58
+ except ImportError:
59
+ HAS_TIKTOKEN = False
60
+
61
# ═══════════════════════════════════════════════════════════════
# TRITON KERNELS (from v18_triton, no autotune, block_ptr)
# ═══════════════════════════════════════════════════════════════

if HAS_TRITON:
    @triton.jit
    def _sparse_bwd_dW_db_kernel(
        X_ptr, dY_ptr, dW_ptr, dB_ptr, chunk_ids_ptr,
        M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr,
        num_active: tl.constexpr,
        stride_xm: tl.constexpr, stride_xk: tl.constexpr,
        stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
        stride_dwn: tl.constexpr, stride_dwk: tl.constexpr,
        HAS_BIAS: tl.constexpr,
        CS: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr,
    ):
        # Weight (and optional bias) gradient for ONE active output chunk.
        # Grid: (num_active, ceil(d_in / BK)).
        # dW[chunk_rows, k_tile] = dY[:, chunk_rows].T @ X[:, k_tile],
        # accumulated over the flattened batch dim M in tiles of BM rows.
        cli = tl.program_id(0)                 # index into chunk_ids (which active chunk)
        kbi = tl.program_id(1)                 # which BK-wide slice of d_in
        cidx = tl.load(chunk_ids_ptr + cli)    # global chunk id
        cs0 = cidx * CS                        # first output row of this chunk
        ko = kbi * BK                          # first input column of this tile

        # dY is addressed transposed as (d_out, M) so tl.dot yields (CS, BK) directly.
        dy_bp = tl.make_block_ptr(dY_ptr, (d_out, M), (stride_dyn, stride_dym),
                                  (cs0, 0), (CS, BM), (1, 0))
        x_bp = tl.make_block_ptr(X_ptr, (M, d_in), (stride_xm, stride_xk),
                                 (0, ko), (BM, BK), (1, 0))

        acc = tl.zeros((CS, BK), dtype=tl.float32)
        # Bias grad is a row-sum of dY; compute it exactly once (the kbi == 0 program).
        do_bias = HAS_BIAS and (kbi == 0)
        acc_b = tl.zeros((CS,), dtype=tl.float32)

        for _ in range(0, M, BM):
            dy_t = tl.load(dy_bp, boundary_check=(0, 1))
            x = tl.load(x_bp, boundary_check=(0, 1))
            acc = tl.dot(dy_t, x, acc=acc)
            if do_bias:
                acc_b += tl.sum(dy_t, axis=1)
            dy_bp = tl.advance(dy_bp, (0, BM))
            x_bp = tl.advance(x_bp, (BM, 0))

        dw_bp = tl.make_block_ptr(dW_ptr, (d_out, d_in), (stride_dwn, stride_dwk),
                                  (cs0, ko), (CS, BK), (1, 0))
        tl.store(dw_bp, acc.to(dW_ptr.dtype.element_ty), boundary_check=(0, 1))

        if do_bias:
            rn = cs0 + tl.arange(0, CS)
            tl.store(dB_ptr + rn, acc_b.to(dB_ptr.dtype.element_ty), mask=rn < d_out)

    @triton.jit
    def _sparse_bwd_dX_kernel(
        dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
        M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr,
        num_active: tl.constexpr,
        stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
        stride_wn: tl.constexpr, stride_wk: tl.constexpr,
        stride_dxm: tl.constexpr, stride_dxk: tl.constexpr,
        CS: tl.constexpr, BM: tl.constexpr, BK: tl.constexpr,
    ):
        # Input gradient restricted to the active chunks:
        # dX[m_tile, k_tile] = sum over active chunks c of dY[:, c] @ W[c, :].
        # Grid: (ceil(M / BM), ceil(d_in / BK)); each program loops over chunks.
        pm = tl.program_id(0)
        pk = tl.program_id(1)
        mo = pm * BM
        ko = pk * BK
        acc = tl.zeros((BM, BK), dtype=tl.float32)
        for i in range(0, num_active):
            cidx = tl.load(chunk_ids_ptr + i)
            cs0 = cidx * CS
            dy_bp = tl.make_block_ptr(dY_ptr, (M, d_out), (stride_dym, stride_dyn),
                                      (mo, cs0), (BM, CS), (1, 0))
            w_bp = tl.make_block_ptr(W_ptr, (d_out, d_in), (stride_wn, stride_wk),
                                     (cs0, ko), (CS, BK), (1, 0))
            dy = tl.load(dy_bp, boundary_check=(0, 1))
            w = tl.load(w_bp, boundary_check=(0, 1))
            acc = tl.dot(dy, w, acc=acc)
        dx_bp = tl.make_block_ptr(dX_ptr, (M, d_in), (stride_dxm, stride_dxk),
                                  (mo, ko), (BM, BK), (1, 0))
        tl.store(dx_bp, acc.to(dX_ptr.dtype.element_ty), boundary_check=(0, 1))
+
138
+
139
def triton_bwd_dW_db(xf, gyf, active, cs, d_out, has_bias):
    """Launch the dW/dB kernel over the active output chunks.

    xf:  (M, d_in) flattened input activations.
    gyf: (M, d_out) flattened upstream gradient.
    active: 1-D tensor of active chunk indices (each chunk = `cs` output rows).
    Returns (dW, dB); rows belonging to inactive chunks stay zero, and dB is
    None when has_bias is False.
    """
    M, d_in = xf.shape
    na = active.numel()
    dW = torch.zeros(d_out, d_in, device=xf.device, dtype=xf.dtype)
    dB = torch.zeros(d_out, device=xf.device, dtype=xf.dtype) if has_bias else None
    if na == 0: return dW, dB
    cids = active.to(torch.int32).contiguous()
    BK, BM = 64, 64
    # With no bias, dW is passed as a dummy dB pointer; the kernel never
    # stores through it because HAS_BIAS is False.
    _sparse_bwd_dW_db_kernel[(na, triton.cdiv(d_in, BK))](
        xf, gyf, dW, dB if has_bias else dW, cids,
        M, d_in, d_out, na,
        xf.stride(0), xf.stride(1), gyf.stride(0), gyf.stride(1),
        dW.stride(0), dW.stride(1),
        HAS_BIAS=has_bias, CS=cs, BK=BK, BM=BM, num_warps=4)
    return dW, dB
+
155
def triton_bwd_dX(gyf, w, active, cs, M, d_in):
    """Launch the chunk-sparse dX kernel.

    gyf: (M, d_out) flattened upstream gradient; w: (d_out, d_in) weight.
    Returns dX of shape (M, d_in) summing contributions from active chunks
    only (all-zero when no chunks are active).
    """
    na = active.numel()
    d_out = gyf.shape[1]
    dX = torch.zeros(M, d_in, device=gyf.device, dtype=gyf.dtype)
    if na == 0: return dX
    cids = active.to(torch.int32).contiguous()
    BM, BK = 64, 64
    _sparse_bwd_dX_kernel[(triton.cdiv(M, BM), triton.cdiv(d_in, BK))](
        gyf, w, dX, cids,
        M, d_in, d_out, na,
        gyf.stride(0), gyf.stride(1), w.stride(0), w.stride(1),
        dX.stride(0), dX.stride(1),
        CS=cs, BM=BM, BK=BK, num_warps=4)
    return dX
+
170
# ═══════════════════════════════════════════════════════════════
# AUTOGRAD
# ═══════════════════════════════════════════════════════════════

class TritonSparseLinearFn(torch.autograd.Function):
    """Dense forward, chunk-sparse backward via the Triton kernels.

    Forward is a plain F.linear; backward fills dW/dB only for active output
    chunks (inactive rows stay zero) and, when sparse_dx is set, restricts
    dX to contributions from active chunks as well.
    """
    @staticmethod
    def forward(ctx, x, w, b, active, cs, sparse_dx):
        ctx.save_for_backward(x, w, active)
        ctx.has_bias = b is not None
        ctx.sparse_dx = sparse_dx
        ctx.cs = cs
        return F.linear(x, w, b)

    @staticmethod
    def backward(ctx, gy):
        x, w, active = ctx.saved_tensors
        cs = ctx.cs
        do, di = w.shape
        # Flatten (..., d) to (M, d) for the kernels.
        xf = x.reshape(-1, di).contiguous()
        gf = gy.reshape(-1, do).contiguous()
        M = xf.shape[0]
        gw, gb = triton_bwd_dW_db(xf, gf, active, cs, do, ctx.has_bias)
        # Full dX (gf @ w) unless the caller asked for chunk-sparse dX too.
        gx = triton_bwd_dX(gf, w.contiguous(), active, cs, M, di) if ctx.sparse_dx else gf @ w
        return gx.reshape(x.shape), gw, gb, None, None, None
+
195
class PyLoopSparseLinearFn(torch.autograd.Function):
    """Pure-PyTorch reference implementation of the chunk-sparse backward.

    Same contract as TritonSparseLinearFn: a dense F.linear forward; the
    backward pass writes weight/bias gradients only for the active output
    chunks and can optionally restrict dX to the active chunks as well.
    """

    @staticmethod
    def forward(ctx, x, w, b, active, cs, sparse_dx):
        ctx.save_for_backward(x, w, active)
        ctx.has_bias = b is not None
        ctx.sparse_dx = sparse_dx
        ctx.cs = cs
        return F.linear(x, w, b)

    @staticmethod
    def backward(ctx, gy):
        x, w, active = ctx.saved_tensors
        chunk = ctx.cs
        x2d = x.reshape(-1, x.shape[-1])
        g2d = gy.reshape(-1, gy.shape[-1])

        grad_w = torch.zeros_like(w)
        grad_b = None
        if ctx.has_bias:
            grad_b = torch.zeros(w.shape[0], device=w.device, dtype=w.dtype)
        # Either accumulate dX chunk-by-chunk, or take the full dense product.
        if ctx.sparse_dx:
            grad_x = torch.zeros_like(x2d)
        else:
            grad_x = g2d @ w

        for cid in active.tolist():
            lo = cid * chunk
            hi = lo + chunk
            g_slice = g2d[:, lo:hi]
            grad_w[lo:hi] = g_slice.t() @ x2d
            if grad_b is not None:
                grad_b[lo:hi] = g_slice.sum(0)
            if ctx.sparse_dx:
                grad_x += g_slice @ w[lo:hi]

        return grad_x.reshape(x.shape), grad_w, grad_b, None, None, None
+
221
# ═══════════════════════════════════════════════════════════════
# MODEL
# ═══════════════════════════════════════════════════════════════

class SparseLinear(nn.Linear):
    """nn.Linear that can route backward through a chunk-sparse autograd fn.

    With sparse_enabled False (or no active_chunks installed) it behaves as
    a plain Linear. Otherwise the scheduler-installed active_chunks select
    which output chunks receive weight/bias gradients.
    """

    def __init__(self, inf, outf, bias=True):
        super().__init__(inf, outf, bias=bias)
        self.sparse_enabled = False    # toggled per step by the training loop
        self.sparse_dx = False         # also restrict dX to active chunks
        self.active_chunks = None      # local chunk ids, installed by scheduler
        self.chunk_size = 64
        self.backend = "triton"        # "triton" or "torch"

    def forward(self, x):
        dense = not self.sparse_enabled or self.active_chunks is None
        if dense:
            return F.linear(x, self.weight, self.bias)
        use_triton = self.backend == "triton" and HAS_TRITON
        sparse_fn = TritonSparseLinearFn if use_triton else PyLoopSparseLinearFn
        return sparse_fn.apply(x, self.weight, self.bias,
                               self.active_chunks, self.chunk_size, self.sparse_dx)
+
240
class Attn(nn.Module):
    """Causal multi-head self-attention with SparseLinear projections.

    d: model width; nh: head count; bs: max sequence length (mask size);
    do: dropout prob applied to the attention weights.
    """
    def __init__(self, d, nh, bs, do):
        super().__init__()
        self.nh, self.hd = nh, d // nh
        self.c_attn = SparseLinear(d, 3*d)   # fused q,k,v projection
        self.c_proj = SparseLinear(d, d)
        self.drop = nn.Dropout(do)
        # Lower-triangular causal mask, broadcastable over (batch, head).
        self.register_buffer("mask", torch.tril(torch.ones(bs,bs)).view(1,1,bs,bs))

    def forward(self, x):
        B,T,C = x.shape
        q,k,v = self.c_attn(x).split(C, 2)
        # (B, T, C) -> (B, nh, T, hd)
        q = q.view(B,T,self.nh,self.hd).transpose(1,2)
        k = k.view(B,T,self.nh,self.hd).transpose(1,2)
        v = v.view(B,T,self.nh,self.hd).transpose(1,2)
        a = (q @ k.transpose(-2,-1)) / math.sqrt(self.hd)
        a = a.masked_fill(self.mask[:,:,:T,:T]==0, float("-inf"))
        # Dropout is applied to the normalized attention weights.
        a = self.drop(F.softmax(a, dim=-1))
        return self.c_proj((a @ v).transpose(1,2).contiguous().view(B,T,C))
+
260
class FFN(nn.Module):
    """Position-wise feed-forward: SparseLinear -> GELU -> SparseLinear -> Dropout."""

    def __init__(self, d, do, ffn_mult=4):
        super().__init__()
        hidden = ffn_mult * d
        self.c_fc = SparseLinear(d, hidden)
        self.c_proj = SparseLinear(hidden, d)
        self.drop = nn.Dropout(do)

    def forward(self, x):
        h = F.gelu(self.c_fc(x))
        return self.drop(self.c_proj(h))
+
269
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each with a residual."""

    def __init__(self, d, nh, bs, do, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = Attn(d, nh, bs, do)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = FFN(d, do, ffn_mult)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
+
278
class GPT(nn.Module):
    """Minimal GPT: token + position embeddings, nl transformer blocks, LN, LM head."""

    def __init__(self, V, bs, nl, nh, d, do, ffn_mult=4):
        super().__init__()
        self.te = nn.Embedding(V, d)
        self.pe = nn.Embedding(bs, d)
        self.blocks = nn.Sequential(*[Block(d, nh, bs, do, ffn_mult) for _ in range(nl)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, V)

    def forward(self, idx, tgt=None):
        """Return (logits, loss); loss is None when no targets are given."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.te(idx) + self.pe(pos)[None]
        lo = self.head(self.ln(self.blocks(x)))
        loss = None
        if tgt is not None:
            loss = F.cross_entropy(lo.view(-1, lo.size(-1)), tgt.view(-1))
        return lo, loss

    def nparams(self):
        """Total parameter count (active and inactive chunks alike)."""
        return sum(p.numel() for p in self.parameters())
+
292
def get_sparse_linears(m):
    """All SparseLinear modules of model m, in module-traversal order."""
    return [mod for mod in m.modules() if isinstance(mod, SparseLinear)]
+
294
# ═══════════════════════════════════════════════════════════════
# DATA
# ═══════════════════════════════════════════════════════════════

class Corpus:
    """Uses tiktoken GPT-2 BPE on Tiny Shakespeare if available, else char-level synthetic."""
    _inst = None  # process-wide cache of the last-built corpus

    @classmethod
    def get(cls, bs, dev):
        """Return a cached corpus, rebuilding if block size OR device changed.

        The previous version keyed the cache on block size only, so switching
        devices between runs kept returning batches on the old device.
        """
        if cls._inst is None or cls._inst.block_size != bs or cls._inst.device != dev:
            cls._inst = cls(bs, dev)
        return cls._inst

    def __init__(self, block_size, device):
        self.block_size, self.device = block_size, device
        import urllib.request
        p = "input.txt"
        if not os.path.exists(p):
            urllib.request.urlretrieve(
                "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", p)
        # Close the handle deterministically and pin the encoding
        # (was `open(p).read()`: leaked handle, locale-dependent decode).
        with open(p, encoding="utf-8") as f:
            text = f.read()
        if HAS_TIKTOKEN:
            enc = tiktoken.get_encoding("gpt2")
            tokens = enc.encode(text)
            self.vocab_size = enc.n_vocab
        else:
            chars = sorted(set(text))
            stoi = {c: i for i, c in enumerate(chars)}
            tokens = [stoi[c] for c in text]
            self.vocab_size = len(chars)
        data = torch.tensor(tokens, dtype=torch.long)
        si = int(0.9 * len(data))  # 90/10 train/val split
        self.train_data, self.val_data = data[:si], data[si:]
        print(f"Corpus: V={self.vocab_size}, train={len(self.train_data):,}, val={len(self.val_data):,}")

    def get_batch(self, split, bs, gen=None):
        """Sample bs (x, y) next-token pairs of length block_size from train/val."""
        d = self.train_data if split == "train" else self.val_data
        ix = torch.randint(len(d) - self.block_size - 1, (bs,), generator=gen)
        x = torch.stack([d[i:i + self.block_size] for i in ix])
        y = torch.stack([d[i + 1:i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)
+
335
def make_gen(s):
    """Fresh CPU RNG seeded with s, for reproducible batch sampling."""
    g = torch.Generator(device="cpu")
    g.manual_seed(s)
    return g
+
338
# ═══════════════════════════════════════════════════════════════
# SCHEDULER (from v18, with KNN)
# ═══════════════════════════════════════════════════════════════

class ChunkScheduler:
    """Selects which output chunks train each step.

    Policies: "random" (uniform), "ema" (EMA of observed per-chunk gradient
    mass), "knn" (EMA plus imputation of inactive chunks' mass from similar
    active chunks via a warmup-built correlation matrix).
    """
    def __init__(self, model, policy, frac, cs, dev, beta=0.95, knn_k=3,
                 sim_hist=128, min_sim_hist=8):
        self.policy, self.frac, self.cs, self.dev = policy, frac, cs, dev
        self.beta, self.knn_k = beta, knn_k
        self.sim_hist, self.min_sim_hist = sim_hist, min_sim_hist
        self.linears = get_sparse_linears(model)
        # Global chunk ids (across all layers) and local ids per layer.
        self.m2ids, self.m2loc = {}, {}
        off = 0
        for m in self.linears:
            m.chunk_size = cs
            nc = m.out_features // cs
            assert m.out_features % cs == 0
            self.m2ids[m] = torch.arange(off, off+nc, device=dev)
            self.m2loc[m] = torch.arange(nc, device=dev)
            off += nc
        self.nc = off                                        # total chunk count
        self.ema = torch.zeros(self.nc, device=dev)          # EMA of grad mass
        self.active = torch.zeros(self.nc, dtype=torch.bool, device=dev)
        self.mass_history = []                               # warmup snapshots for similarity
        self.similarity = None                               # (nc, nc) correlation, intra-layer only
        self.scores = torch.zeros(self.nc, device=dev)       # selection scores (knn-imputed)

    def get_frac(self, step, wu, an):
        """Active fraction: 1.0 during warmup, cosine-annealed to frac, then frac."""
        if step < wu: return 1.0
        if an > 0 and step < wu + an:
            p = (step - wu) / an
            return self.frac + (1-self.frac) * 0.5 * (1 + math.cos(math.pi * p))
        return self.frac

    def choose(self, step, wu, an):
        """Pick the active set for this step and install it into the layers."""
        f = self.get_frac(step, wu, an)
        if f >= 0.999:
            self.active.fill_(True)
            self._install(); return
        k = max(1, int(f * self.nc))
        self.active.fill_(False)
        if self.policy == "random":
            idx = torch.randperm(self.nc, device=self.dev)[:k]
        elif self.policy == "ema":
            # Tiny random jitter breaks ties among equal scores.
            idx = torch.topk(self.ema + 1e-9*torch.rand_like(self.ema), k=k).indices
        elif self.policy == "knn":
            # Fall back to the raw EMA until KNN scores are populated.
            base = self.scores if self.scores.sum() > 1e-12 else self.ema
            idx = torch.topk(base + 1e-9*torch.rand_like(base), k=k).indices
        else:
            raise ValueError(self.policy)
        self.active[idx] = True
        self._install()

    def _install(self):
        # Translate the global active mask into per-layer local chunk ids.
        for m, gids in self.m2ids.items():
            m.active_chunks = self.m2loc[m][self.active[gids]]

    @torch.no_grad()
    def update(self, step, wu):
        """Fold this step's observed gradient masses into EMA/scores.

        Returns the per-chunk mass vector `cur` (zero for unobserved chunks).
        """
        cur = torch.zeros_like(self.ema)
        for m, ids in self.m2ids.items():
            if m.weight.grad is None: continue
            s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2))
            if m.bias is not None and m.bias.grad is not None:
                s += m.bias.grad.square().view(len(ids), self.cs).sum(1)
            cur[ids] = torch.sqrt(s + 1e-30)
        obs = self.active
        # First observation initializes the EMA; later ones blend with beta.
        new = obs & (self.ema == 0)
        old = obs & ~new
        self.ema[new] = cur[new]
        self.ema[old] = self.beta*self.ema[old] + (1-self.beta)*cur[old]
        # KNN similarity building during warmup
        if step < wu:
            self.mass_history.append(cur.clone())
            if len(self.mass_history) > self.sim_hist:
                self.mass_history = self.mass_history[-self.sim_hist:]
            if len(self.mass_history) >= self.min_sim_hist:
                self.similarity = self._build_sim()
        # Scores refresh every step: KNN imputes inactive chunks from this
        # step's observed masses. NOTE(review): source indentation was lost
        # in extraction; placed at method level, which choose() relies on.
        if self.policy == "knn":
            self.scores = self._knn_scores(self.active, cur)
        else:
            self.scores = self.ema.clone()
        return cur

    def _build_sim(self):
        """Correlation of per-chunk mass trajectories, clamped >= 0, intra-layer only."""
        H = torch.stack(self.mass_history)
        H = (H - H.mean(0, keepdim=True)) / (H.std(0, keepdim=True) + 1e-6)
        S = torch.clamp((H.T @ H) / max(1, H.shape[0]-1), min=0)
        S.fill_diagonal_(0)
        # Mask out cross-layer pairs: only same-layer chunks may be neighbors.
        ok = torch.zeros_like(S, dtype=torch.bool)
        for _, ids in self.m2ids.items():
            ok[ids[:,None], ids[None,:]] = True
        return torch.where(ok, S, torch.zeros_like(S))

    def _knn_scores(self, active_mask, cur):
        """Active chunks score by observed mass; inactive ones by a similarity-
        weighted mean of their top-k active neighbors' masses."""
        if self.similarity is None: return self.ema.clone()
        sc = self.ema.clone()
        sc[active_mask] = cur[active_mask]
        aidx = active_mask.nonzero(as_tuple=False).flatten()
        iidx = (~active_mask).nonzero(as_tuple=False).flatten()
        if aidx.numel() == 0: return sc
        S = self.similarity
        for i in iidx.tolist():
            w = S[i, aidx]
            if w.sum() <= 1e-12: continue
            kk = min(self.knn_k, w.numel())
            top = torch.topk(w, k=kk)
            sc[i] = (top.values * cur[aidx[top.indices]]).sum() / (top.values.sum() + 1e-12)
        return sc

    @torch.no_grad()
    def oracle_scores(self):
        """Compute dense gradient magnitudes per chunk (requires dense grads already computed)."""
        sc = torch.zeros(self.nc, device=self.dev)
        for m, ids in self.m2ids.items():
            if m.weight.grad is None: continue
            s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2))
            if m.bias is not None and m.bias.grad is not None:
                s += m.bias.grad.square().view(len(ids), self.cs).sum(1)
            sc[ids] = torch.sqrt(s + 1e-30)
        return sc

    def measure_overlap(self, k):
        """Jaccard and recall of current active vs oracle top-k."""
        oracle = set(torch.topk(self.oracle_scores(), k=k).indices.tolist())
        pred = set(self.active.nonzero(as_tuple=True)[0].tolist())
        if not oracle or not pred: return 0., 0.
        inter = oracle & pred
        return len(inter)/len(oracle|pred), len(inter)/len(oracle)
+
468
# ═══════════════════════════════════════════════════════════════
# CHUNKED ADAM WITH PHANTOM/FROZEN MODES
# ═══════════════════════════════════════════════════════════════

class ChunkedAdam:
    """
    Adam with two modes for inactive chunks:
      phantom: standard — m,v decay even on zero grad (default, original behavior)
      frozen:  m,v state completely frozen for inactive chunks

    Fixed betas (0.9, 0.999), eps 1e-8; no bias correction in any mode
    (a deliberate simplification shared by all compared configurations).
    """
    def __init__(self, model, lr=3e-4, cs=64, momentum_mode="phantom"):
        self.model, self.lr, self.cs = model, lr, cs
        self.momentum_mode = momentum_mode  # "phantom" or "frozen"
        self.state = {}   # param -> {"m": ..., "v": ...}, created lazily
        self.p2m = {}     # param -> owning SparseLinear (to read active_chunks)
        for m in get_sparse_linears(model):
            if m.weight is not None: self.p2m[m.weight] = m
            if m.bias is not None: self.p2m[m.bias] = m

    def zero_grad(self):
        for p in self.model.parameters(): p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None: continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m, v = self.state[p]["m"], self.state[p]["v"]
            sm = self.p2m.get(p)
            ac = getattr(sm, 'active_chunks', None) if sm else None

            if ac is None:
                # Dense parameter (LN, embeddings, lm_head) — always full update
                m.mul_(0.9).add_(p.grad, alpha=0.1)
                v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                p.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr)
            else:
                if self.momentum_mode == "phantom":
                    # PHANTOM: update ALL chunks' moments, but only active get real gradients.
                    # Inactive chunks see grad=0, so m decays and v decays.
                    # This is the original behavior.
                    m.mul_(0.9).add_(p.grad, alpha=0.1)
                    v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                    # But only update weights for active chunks
                    for c in ac.tolist():
                        s, e = c*self.cs, (c+1)*self.cs
                        p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr)
                elif self.momentum_mode == "frozen":
                    # FROZEN: only touch m,v,p for active chunks. Inactive state is untouched.
                    for c in ac.tolist():
                        s, e = c*self.cs, (c+1)*self.cs
                        g = p.grad[s:e]
                        m[s:e].mul_(0.9).add_(g, alpha=0.1)
                        v[s:e].mul_(0.999).addcmul_(g, g, value=0.001)
                        p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr)
+
525
# ═══════════════════════════════════════════════════════════════
# EVALUATION
# ═══════════════════════════════════════════════════════════════

@torch.no_grad()
def evaluate(model, corpus, bs, n=20, seed=9999):
    """Mean validation loss and (overflow-capped) perplexity over n fixed-seed batches."""
    model.eval()
    total = 0.0
    for i in range(n):
        xb, yb = corpus.get_batch("val", bs, make_gen(seed + i))
        _, loss = model(xb, yb)
        total += loss.item()
    model.train()
    mean = total / n
    # Cap the exponent so early, high-loss models do not overflow to inf.
    return mean, math.exp(min(mean, 20))
+
540
# ═══════════════════════════════════════════════════════════════
# SINGLE TRAINING RUN
# ═══════════════════════════════════════════════════════════════

def run(policy, bwd_mode, steps, bs, block_size, nl, nh, d, cs,
        active_frac, wu, an, lr, device, seed, backend="triton",
        momentum_mode="phantom", ffn_mult=4,
        measure_oracle=False, oracle_interval=50):
    """Run one training config. Returns dict of results.

    policy: "dense" or a ChunkScheduler policy ("random"/"ema"/"knn").
    bwd_mode: "full_dX" or "sparse_dX" (restrict dX to active chunks).
    wu/an: warmup and anneal step counts for the active fraction schedule.
    measure_oracle: every oracle_interval steps (after warmup+anneal) compute
    a dense gradient to score active-set overlap with the oracle top-k.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    random.seed(seed)

    corpus = Corpus.get(block_size, device)
    model = GPT(corpus.vocab_size, block_size, nl, nh, d, 0.1, ffn_mult).to(device)
    for m in get_sparse_linears(model):
        m.chunk_size = cs
        m.backend = backend

    is_dense = (policy == "dense")
    sched = None if is_dense else ChunkScheduler(model, policy, active_frac, cs, device)
    opt = ChunkedAdam(model, lr=lr, cs=cs, momentum_mode=momentum_mode)

    np_ = model.nparams()
    overlaps = []  # (step, jaccard, recall) tuples from oracle measurements

    torch.cuda.synchronize() if device == "cuda" else None
    t0 = time.perf_counter()

    for step in range(steps):
        # Per-step seeded generator: identical batch order across configs.
        x, y = corpus.get_batch("train", bs, make_gen(step))

        if is_dense:
            for m in get_sparse_linears(model):
                m.sparse_enabled = False; m.active_chunks = None
        else:
            sched.choose(step, wu, an)
            for m in get_sparse_linears(model):
                m.sparse_enabled = True
                m.sparse_dx = (bwd_mode == "sparse_dX")

        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        if sched:
            sched.update(step, wu)

        # Oracle overlap measurement: temporarily take a dense backward pass
        # on the same batch, score overlap, then restore the sparse grads.
        if measure_oracle and step % oracle_interval == 0 and step >= wu + an:
            saved = {p: p.grad.clone() for p in model.parameters() if p.grad is not None}
            for m in get_sparse_linears(model): m.sparse_enabled = False
            for p in model.parameters(): p.grad = None
            _, lo = model(x, y); lo.backward()
            k = max(1, int(active_frac * sched.nc))
            j, r = sched.measure_overlap(k)
            overlaps.append((step, j, r))
            for p in model.parameters():
                if p in saved: p.grad = saved[p]
            for m in get_sparse_linears(model): m.sparse_enabled = True

        opt.step()

        if step % 200 == 0:
            print(f"  step {step}/{steps} loss={loss.item():.4f}")

    torch.cuda.synchronize() if device == "cuda" else None
    wall = time.perf_counter() - t0

    # Evaluate densely (all chunks participate in the forward pass).
    for m in get_sparse_linears(model): m.sparse_enabled = False
    vl, vp = evaluate(model, corpus, bs, n=30)

    del model; torch.cuda.empty_cache() if device == "cuda" else None

    return {
        "val_loss": vl, "val_ppl": vp, "wall_time": wall,
        "ms_per_step": 1000*wall/steps, "n_params": np_,
        "train_loss_final": loss.item(), "overlaps": overlaps,
    }
+
620
def run_seeds(cfg, seeds):
    """Run cfg once per seed; return mean/std val loss plus per-seed results.

    The caller's cfg dict is left untouched (the previous version wrote
    cfg["seed"] in place, leaking the last seed back to the caller).
    """
    results = [run(**{**cfg, "seed": s}) for s in seeds]
    vls = [r["val_loss"] for r in results]
    ml = sum(vls) / len(vls)
    # Sample standard deviation; max(1, ...) guards the single-seed case.
    sl = (sum((x - ml) ** 2 for x in vls) / max(1, len(vls) - 1)) ** 0.5
    return {"mean_loss": ml, "std_loss": sl, "results": results,
            "mean_ms": sum(r["ms_per_step"] for r in results) / len(results)}
+
631
# ═══════════════════════════════════════════════════════════════
# EXPERIMENT 1: PHANTOM MOMENTUM ABLATION
# ═══════════════════════════════════════════════════════════════

def exp_phantom_momentum(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend):
    """Cross each scheduler policy with {phantom, frozen} Adam state handling
    for inactive chunks, against a dense baseline. Returns name -> run_seeds
    summary dict."""
    print("\n" + "="*80)
    print("EXPERIMENT 1: Phantom Momentum Ablation")
    print("="*80)

    base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size,
                nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an,
                lr=lr, device=device, backend=backend)

    # (label, scheduler policy, momentum mode)
    configs = [
        ("dense", "dense", "phantom"),
        ("ema+phantom", "ema", "phantom"),
        ("ema+frozen", "ema", "frozen"),
        ("knn+phantom", "knn", "phantom"),
        ("knn+frozen", "knn", "frozen"),
        ("random+phantom", "random", "phantom"),
        ("random+frozen", "random", "frozen"),
    ]

    results = {}
    for name, policy, mm in configs:
        print(f"\n--- {name} ---")
        cfg = {**base, "policy": policy, "momentum_mode": mm}
        results[name] = run_seeds(cfg, seeds)

    # Summary table.
    print(f"\n{'Method':<22} | {'Val Loss':>18} | {'ms/step':>10}")
    print("-"*55)
    for name, _, _ in configs:
        r = results[name]
        print(f"{name:<22} | {r['mean_loss']:.4f} ± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}")

    return results
+
668
# ═══════════════════════════════════════════════════════════════
# EXPERIMENT 2: COMPUTE-MATCHED BASELINES
# ═══════════════════════════════════════════════════════════════

def exp_compute_matched(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend):
    """Compare sparse training against compute/capacity-matched dense baselines.

    Baselines: dense at the same step count, dense at a FLOP-matched (reduced)
    step count, and a natively smaller dense model (reduced FFN multiplier)
    matching the sparse model's active capacity. Returns name -> run_seeds
    summary dict.
    """
    print("\n" + "="*80)
    print("EXPERIMENT 2: Compute-Matched Baselines")
    print("="*80)

    base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size,
                nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an,
                lr=lr, device=device, backend=backend, momentum_mode="phantom")

    # 1. Sparse reference
    print("\n--- Sparse (EMA, reference) ---")
    sparse_r = run_seeds({**base, "policy": "ema"}, seeds)

    # 2. Dense at same steps
    print("\n--- Dense (same steps) ---")
    dense_same = run_seeds({**base, "policy": "dense"}, seeds)

    # 3. Dense at compute-matched steps.
    # Per step, sparse costs ~(fwd dense + dX dense + dW at fraction af) of
    # the three dense matmul passes.
    ratio = (1.0 + 1.0 + af) / 3.0
    matched_steps = int(steps * ratio)
    print(f"\n--- Dense (compute-matched, {matched_steps} steps) ---")
    dense_matched = run_seeds({**base, "policy": "dense", "steps": matched_steps}, seeds)

    # 4. Natively smaller dense model: FFN multiplier scaled by the active fraction.
    small_ffn_mult = max(1, round(4 * af))  # e.g. 4*0.1 = 0.4 -> clamped to 1
    print(f"\n--- Small dense (ffn_mult={small_ffn_mult}, capacity-matched) ---")
    dense_small = run_seeds({**base, "policy": "dense", "ffn_mult": small_ffn_mult}, seeds)

    # Carry the real step count with each row. The previous table re-derived
    # it by sniffing the result-key name (and a .get("steps", ...) on a key
    # run() never produces), which was dead logic.
    rows = [
        ("sparse_ema", sparse_r, steps),
        ("dense_same_steps", dense_same, steps),
        (f"dense_matched_{matched_steps}steps", dense_matched, matched_steps),
        (f"dense_small_ffn{small_ffn_mult}", dense_small, steps),
    ]
    results = {name: r for name, r, _ in rows}

    print(f"\n{'Method':<35} | {'Steps':>6} | {'Params':>8} | {'Val Loss':>18} | {'ms/step':>10}")
    print("-"*90)
    for name, r, st in rows:
        np_ = r["results"][0]["n_params"]
        print(f"{name:<35} | {st:>6} | {np_/1e6:>7.1f}M | {r['mean_loss']:.4f} ± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}")

    return results
+
719
# ═══════════════════════════════════════════════════════════════
# EXPERIMENT 3: PREDICTOR ACCURACY (EMA vs KNN vs Oracle)
# ═══════════════════════════════════════════════════════════════

def exp_predictor_accuracy(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend):
    """Train with each predictor policy while periodically measuring overlap
    (Jaccard/recall) of its active set against the dense-gradient oracle.
    Returns policy -> run_seeds summary dict."""
    print("\n" + "="*80)
    print("EXPERIMENT 3: Predictor Accuracy (EMA vs KNN vs Oracle)")
    print("="*80)

    base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size,
                nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an,
                lr=lr, device=device, backend=backend, momentum_mode="phantom",
                measure_oracle=True, oracle_interval=25)

    results = {}
    for policy in ["ema", "knn", "random"]:
        print(f"\n--- {policy} ---")
        results[policy] = run_seeds({**base, "policy": policy}, seeds)

    # Aggregate overlaps: average Jaccard/recall per measured step across seeds.
    for policy in ["ema", "knn", "random"]:
        print(f"\n{policy.upper()} predictor overlap:")
        print(f"  {'Step':>6} | {'Jaccard':>10} | {'Recall':>10}")
        sd = defaultdict(lambda: {"j": [], "r": []})
        for res in results[policy]["results"]:
            for s, j, r in res["overlaps"]:
                sd[s]["j"].append(j); sd[s]["r"].append(r)
        for s in sorted(sd):
            mj = sum(sd[s]["j"])/len(sd[s]["j"])
            mr = sum(sd[s]["r"])/len(sd[s]["r"])
            print(f"  {s:>6} | {mj:>10.4f} | {mr:>10.4f}")

    print(f"\n{'Policy':<10} | {'Val Loss':>18} | {'ms/step':>10}")
    print("-"*45)
    for p in ["ema", "knn", "random"]:
        r = results[p]
        print(f"{p:<10} | {r['mean_loss']:.4f} ± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}")

    return results
+
759
# ═══════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════

# Experiment registry: name -> runner taking the shared config kwargs.
# main() iterates this dict when --experiment all is requested.
ALL_EXPS = {
    "phantom_momentum": exp_phantom_momentum,
    "compute_matched": exp_compute_matched,
    "predictor_accuracy": exp_predictor_accuracy,
}
+
769
def main():
    """Parse CLI args, run the selected experiment(s), dump JSON results."""
    p = argparse.ArgumentParser()
    p.add_argument("--experiment", default="all", choices=list(ALL_EXPS)+["all"])
    p.add_argument("--device", default="cuda")
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--seeds", default="42,123,456")  # comma-separated ints
    p.add_argument("--n_embd", type=int, default=1024)
    p.add_argument("--n_layer", type=int, default=4)
    p.add_argument("--n_head", type=int, default=8)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--block_size", type=int, default=256)
    p.add_argument("--chunk_size", type=int, default=64)
    p.add_argument("--active_fraction", type=float, default=0.10)
    p.add_argument("--warmup_steps", type=int, default=50)
    p.add_argument("--anneal_steps", type=int, default=200)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--backend", default="triton", choices=["triton", "torch"])
    p.add_argument("--output_dir", default="results")
    args = p.parse_args()

    seeds = [int(s) for s in args.seeds.split(",")]
    os.makedirs(args.output_dir, exist_ok=True)

    if args.device == "cuda" and torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
    print(f"Config: d={args.n_embd} nl={args.n_layer} nh={args.n_head} steps={args.steps} seeds={seeds}")
    print(f"        cs={args.chunk_size} af={args.active_fraction} backend={args.backend}")

    # Kwargs shared by every experiment function.
    shared = dict(device=args.device, steps=args.steps, seeds=seeds,
                  d=args.n_embd, nl=args.n_layer, nh=args.n_head,
                  bs=args.batch_size, block_size=args.block_size,
                  cs=args.chunk_size, af=args.active_fraction,
                  wu=args.warmup_steps, an=args.anneal_steps,
                  lr=args.lr, backend=args.backend)

    exps = ALL_EXPS if args.experiment == "all" else {args.experiment: ALL_EXPS[args.experiment]}
    t0 = time.time()

    for name, fn in exps.items():
        print(f"\n{'#'*80}\n# {name} ({(time.time()-t0)/60:.1f}m elapsed)\n{'#'*80}")
        sys.stdout.flush()
        result = fn(**shared)

        # Recursively stringify dict keys (e.g. module objects) for JSON;
        # default=str catches any remaining non-serializable values.
        def ser(o):
            if isinstance(o, dict): return {str(k): ser(v) for k,v in o.items()}
            if isinstance(o, list): return [ser(x) for x in o]
            return o

        with open(os.path.join(args.output_dir, f"{name}.json"), "w") as f:
            json.dump(ser(result), f, indent=2, default=str)
        print(f"✓ {name} saved to {args.output_dir}/{name}.json")

    print(f"\n{'='*80}\nALL COMPLETE in {(time.time()-t0)/60:.1f} minutes\n{'='*80}")

if __name__ == "__main__":
    main()