Upload ablations.py with huggingface_hub
Browse files- ablations.py +824 -1
ablations.py
CHANGED
|
@@ -1 +1,824 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Sparse Transformer: Definitive Ablation Suite
|
| 4 |
+
|
| 5 |
+
Builds on v18_fast_knn_triton.py. Addresses all three structural gaps
|
| 6 |
+
identified in the critique:
|
| 7 |
+
|
| 8 |
+
1. PHANTOM MOMENTUM ABLATION
|
| 9 |
+
- "phantom": standard Adam β inactive chunks' moments decay on zero grad (default)
|
| 10 |
+
- "frozen": inactive chunks' Adam state (m, v) is completely frozen
|
| 11 |
+
Compare across all schedulers to isolate whether convergence is driven
|
| 12 |
+
by the chunking algorithm or by phantom momentum acting as regularization.
|
| 13 |
+
|
| 14 |
+
2. COMPUTE-MATCHED BASELINES
|
| 15 |
+
- Dense at same steps (standard comparison)
|
| 16 |
+
- Dense at fewer steps matching sparse FLOPs
|
| 17 |
+
- Natively smaller dense model matching sparse active capacity
|
| 18 |
+
|
| 19 |
+
3. UNIFIED HARDWARE
|
| 20 |
+
Everything on CUDA (A10G). Single hardware stack.
|
| 21 |
+
|
| 22 |
+
Plus: KNN vs EMA vs Random vs Oracle predictor comparison with proper
|
| 23 |
+
oracle overlap measurement.
|
| 24 |
+
|
| 25 |
+
Run:
|
| 26 |
+
python ablations.py --device cuda --steps 1000 --n_embd 1024 --experiment all
|
| 27 |
+
python ablations.py --device cuda --experiment phantom_momentum
|
| 28 |
+
python ablations.py --device cuda --experiment compute_matched
|
| 29 |
+
python ablations.py --device cuda --experiment predictor_accuracy
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import argparse
|
| 35 |
+
import json
|
| 36 |
+
import math
|
| 37 |
+
import os
|
| 38 |
+
import random
|
| 39 |
+
import sys
|
| 40 |
+
import time
|
| 41 |
+
from collections import defaultdict
|
| 42 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 43 |
+
|
| 44 |
+
import torch
|
| 45 |
+
import torch.nn as nn
|
| 46 |
+
import torch.nn.functional as F
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
import triton
|
| 50 |
+
import triton.language as tl
|
| 51 |
+
HAS_TRITON = True
|
| 52 |
+
except ImportError:
|
| 53 |
+
HAS_TRITON = False
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
import tiktoken
|
| 57 |
+
HAS_TIKTOKEN = True
|
| 58 |
+
except ImportError:
|
| 59 |
+
HAS_TIKTOKEN = False
|
| 60 |
+
|
| 61 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
# TRITON KERNELS (from v18_triton, no autotune, block_ptr)
|
| 63 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 64 |
+
|
| 65 |
+
if HAS_TRITON:
|
| 66 |
+
    @triton.jit
    def _sparse_bwd_dW_db_kernel(
        X_ptr, dY_ptr, dW_ptr, dB_ptr, chunk_ids_ptr,
        M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr,
        num_active: tl.constexpr,
        stride_xm: tl.constexpr, stride_xk: tl.constexpr,
        stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
        stride_dwn: tl.constexpr, stride_dwk: tl.constexpr,
        HAS_BIAS: tl.constexpr,
        CS: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr,
    ):
        # One program per (active chunk, BK-wide d_in tile). Computes
        #   dW[rows, ko:ko+BK] = dY[:, rows].T @ X[:, ko:ko+BK]
        # for the CS output rows belonging to that chunk, and (first k-tile
        # only) dB[rows] = dY[:, rows].sum(0).
        cli = tl.program_id(0)   # position in the active-chunk list
        kbi = tl.program_id(1)   # which d_in tile
        cidx = tl.load(chunk_ids_ptr + cli)  # global chunk index
        cs0 = cidx * CS          # first output row of this chunk
        ko = kbi * BK            # first input column of this tile

        # dY is addressed transposed (d_out, M) by swapping its strides.
        dy_bp = tl.make_block_ptr(dY_ptr, (d_out, M), (stride_dyn, stride_dym),
                                  (cs0, 0), (CS, BM), (1, 0))
        x_bp = tl.make_block_ptr(X_ptr, (M, d_in), (stride_xm, stride_xk),
                                 (0, ko), (BM, BK), (1, 0))

        acc = tl.zeros((CS, BK), dtype=tl.float32)
        # Only the first k-tile accumulates dB, so dB is written exactly once.
        do_bias = HAS_BIAS and (kbi == 0)
        acc_b = tl.zeros((CS,), dtype=tl.float32)

        # March over the M (flattened token) dimension in BM steps.
        for _ in range(0, M, BM):
            dy_t = tl.load(dy_bp, boundary_check=(0, 1))
            x = tl.load(x_bp, boundary_check=(0, 1))
            acc = tl.dot(dy_t, x, acc=acc)
            if do_bias:
                acc_b += tl.sum(dy_t, axis=1)
            dy_bp = tl.advance(dy_bp, (0, BM))
            x_bp = tl.advance(x_bp, (BM, 0))

        dw_bp = tl.make_block_ptr(dW_ptr, (d_out, d_in), (stride_dwn, stride_dwk),
                                  (cs0, ko), (CS, BK), (1, 0))
        tl.store(dw_bp, acc.to(dW_ptr.dtype.element_ty), boundary_check=(0, 1))

        if do_bias:
            rn = cs0 + tl.arange(0, CS)
            tl.store(dB_ptr + rn, acc_b.to(dB_ptr.dtype.element_ty), mask=rn < d_out)
|
| 108 |
+
|
| 109 |
+
    @triton.jit
    def _sparse_bwd_dX_kernel(
        dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
        M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr,
        num_active: tl.constexpr,
        stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
        stride_wn: tl.constexpr, stride_wk: tl.constexpr,
        stride_dxm: tl.constexpr, stride_dxk: tl.constexpr,
        CS: tl.constexpr, BM: tl.constexpr, BK: tl.constexpr,
    ):
        # One program per (BM x BK) tile of dX. Accumulates
        #   dX = sum over active chunks c of dY[:, rows(c)] @ W[rows(c), :]
        # i.e. the input gradient restricted to active output chunks.
        pm = tl.program_id(0)
        pk = tl.program_id(1)
        mo = pm * BM   # row offset of this dX tile
        ko = pk * BK   # column offset of this dX tile
        acc = tl.zeros((BM, BK), dtype=tl.float32)
        # Serial loop over the active chunk list; each iteration contributes
        # a CS-wide rank update from that chunk's dY columns and W rows.
        for i in range(0, num_active):
            cidx = tl.load(chunk_ids_ptr + i)
            cs0 = cidx * CS
            dy_bp = tl.make_block_ptr(dY_ptr, (M, d_out), (stride_dym, stride_dyn),
                                      (mo, cs0), (BM, CS), (1, 0))
            w_bp = tl.make_block_ptr(W_ptr, (d_out, d_in), (stride_wn, stride_wk),
                                     (cs0, ko), (CS, BK), (1, 0))
            dy = tl.load(dy_bp, boundary_check=(0, 1))
            w = tl.load(w_bp, boundary_check=(0, 1))
            acc = tl.dot(dy, w, acc=acc)
        dx_bp = tl.make_block_ptr(dX_ptr, (M, d_in), (stride_dxm, stride_dxk),
                                  (mo, ko), (BM, BK), (1, 0))
        tl.store(dx_bp, acc.to(dX_ptr.dtype.element_ty), boundary_check=(0, 1))
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def triton_bwd_dW_db(xf, gyf, active, cs, d_out, has_bias):
    """Launch the dW/dB kernel over the active chunks.

    xf:  (M, d_in) flattened input; gyf: (M, d_out) flattened output grad.
    Returns (dW, dB) with zeros everywhere outside the active chunks
    (dB is None when has_bias is False).
    """
    M, d_in = xf.shape
    na = active.numel()
    dW = torch.zeros(d_out, d_in, device=xf.device, dtype=xf.dtype)
    dB = torch.zeros(d_out, device=xf.device, dtype=xf.dtype) if has_bias else None
    if na == 0: return dW, dB
    cids = active.to(torch.int32).contiguous()
    BK, BM = 64, 64
    # Grid: one program per (active chunk, d_in tile).
    # When has_bias is False, dW is passed as a dummy dB pointer; the kernel
    # never writes it because HAS_BIAS gates the bias store.
    _sparse_bwd_dW_db_kernel[(na, triton.cdiv(d_in, BK))](
        xf, gyf, dW, dB if has_bias else dW, cids,
        M, d_in, d_out, na,
        xf.stride(0), xf.stride(1), gyf.stride(0), gyf.stride(1),
        dW.stride(0), dW.stride(1),
        HAS_BIAS=has_bias, CS=cs, BK=BK, BM=BM, num_warps=4)
    return dW, dB
|
| 154 |
+
|
| 155 |
+
def triton_bwd_dX(gyf, w, active, cs, M, d_in):
    """Launch the dX kernel: input gradient using only active output chunks.

    gyf: (M, d_out) flattened output grad; w: (d_out, d_in) weight.
    Returns a dense (M, d_in) dX (zeros when no chunk is active).
    """
    na = active.numel()
    d_out = gyf.shape[1]
    dX = torch.zeros(M, d_in, device=gyf.device, dtype=gyf.dtype)
    if na == 0: return dX
    cids = active.to(torch.int32).contiguous()
    BM, BK = 64, 64
    # Grid tiles the full dX; each program loops over the active chunk list.
    _sparse_bwd_dX_kernel[(triton.cdiv(M, BM), triton.cdiv(d_in, BK))](
        gyf, w, dX, cids,
        M, d_in, d_out, na,
        gyf.stride(0), gyf.stride(1), w.stride(0), w.stride(1),
        dX.stride(0), dX.stride(1),
        CS=cs, BM=BM, BK=BK, num_warps=4)
    return dX
|
| 169 |
+
|
| 170 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 171 |
+
# AUTOGRAD
|
| 172 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 173 |
+
|
| 174 |
+
class TritonSparseLinearFn(torch.autograd.Function):
    """Dense forward; Triton backward restricted to the active output chunks."""

    @staticmethod
    def forward(ctx, x, w, b, active, cs, sparse_dx):
        ctx.save_for_backward(x, w, active)
        ctx.has_bias = b is not None
        ctx.sparse_dx = sparse_dx
        ctx.cs = cs
        return F.linear(x, w, b)

    @staticmethod
    def backward(ctx, gy):
        x, w, active = ctx.saved_tensors
        cs = ctx.cs
        do, di = w.shape
        # Collapse leading (batch, seq) dims into a single M dimension.
        xf = x.reshape(-1, di).contiguous()
        gf = gy.reshape(-1, do).contiguous()
        M = xf.shape[0]
        gw, gb = triton_bwd_dW_db(xf, gf, active, cs, do, ctx.has_bias)
        # dX: chunk-sparse Triton path, or the plain dense matmul.
        gx = triton_bwd_dX(gf, w.contiguous(), active, cs, M, di) if ctx.sparse_dx else gf @ w
        # One grad slot per forward arg; active/cs/sparse_dx are non-differentiable.
        return gx.reshape(x.shape), gw, gb, None, None, None
|
| 194 |
+
|
| 195 |
+
class PyLoopSparseLinearFn(torch.autograd.Function):
    """Pure-PyTorch reference for the chunk-sparse linear.

    Forward is an ordinary dense F.linear; backward produces weight/bias
    gradients only for the chunks listed in *active* (others stay zero),
    and optionally restricts the input gradient to those chunks as well.
    """

    @staticmethod
    def forward(ctx, x, w, b, active, cs, sparse_dx):
        ctx.save_for_backward(x, w, active)
        ctx.has_bias = b is not None
        ctx.sparse_dx = sparse_dx
        ctx.cs = cs
        return F.linear(x, w, b)

    @staticmethod
    def backward(ctx, gy):
        inp, weight, active = ctx.saved_tensors
        chunk = ctx.cs
        flat_in = inp.reshape(-1, inp.shape[-1])
        flat_gy = gy.reshape(-1, gy.shape[-1])

        grad_w = torch.zeros_like(weight)
        grad_b = None
        if ctx.has_bias:
            grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)
        # sparse_dx: accumulate dX chunk-by-chunk; otherwise one dense matmul.
        if ctx.sparse_dx:
            grad_x = torch.zeros_like(flat_in)
        else:
            grad_x = flat_gy @ weight

        for cid in active.tolist():
            lo = cid * chunk
            hi = lo + chunk
            gy_slice = flat_gy[:, lo:hi]
            grad_w[lo:hi] = gy_slice.t() @ flat_in
            if grad_b is not None:
                grad_b[lo:hi] = gy_slice.sum(0)
            if ctx.sparse_dx:
                grad_x += gy_slice @ weight[lo:hi]

        # One grad per forward arg; active/cs/sparse_dx take no gradient.
        return grad_x.reshape(inp.shape), grad_w, grad_b, None, None, None
|
| 220 |
+
|
| 221 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 222 |
+
# MODEL
|
| 223 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 224 |
+
|
| 225 |
+
class SparseLinear(nn.Linear):
    """nn.Linear whose backward pass can be limited to a subset of output chunks.

    With sparse mode off (or no active chunks installed) it is an ordinary
    dense linear layer; otherwise it routes through one of the custom
    autograd functions (Triton if available and selected, else the pure
    PyTorch loop).
    """

    def __init__(self, inf, outf, bias=True):
        super().__init__(inf, outf, bias=bias)
        self.sparse_enabled = False   # master on/off switch
        self.sparse_dx = False        # also sparsify the input gradient?
        self.active_chunks = None     # tensor of local chunk indices, or None
        self.chunk_size = 64
        self.backend = "triton"       # "triton" or "torch"

    def forward(self, x):
        use_sparse = self.sparse_enabled and self.active_chunks is not None
        if not use_sparse:
            return F.linear(x, self.weight, self.bias)
        if self.backend == "triton" and HAS_TRITON:
            fn = TritonSparseLinearFn
        else:
            fn = PyLoopSparseLinearFn
        return fn.apply(x, self.weight, self.bias, self.active_chunks,
                        self.chunk_size, self.sparse_dx)
|
| 239 |
+
|
| 240 |
+
class Attn(nn.Module):
    """Causal multi-head self-attention using SparseLinear projections."""

    def __init__(self, d, nh, bs, do):
        super().__init__()
        self.nh, self.hd = nh, d // nh  # head count, per-head dim
        self.c_attn = SparseLinear(d, 3*d)  # fused q,k,v projection
        self.c_proj = SparseLinear(d, d)    # output projection
        self.drop = nn.Dropout(do)
        # Lower-triangular causal mask, broadcastable over (batch, head).
        self.register_buffer("mask", torch.tril(torch.ones(bs,bs)).view(1,1,bs,bs))

    def forward(self, x):
        B,T,C = x.shape
        q,k,v = self.c_attn(x).split(C, 2)
        q = q.view(B,T,self.nh,self.hd).transpose(1,2)  # (B, nh, T, hd)
        k = k.view(B,T,self.nh,self.hd).transpose(1,2)
        v = v.view(B,T,self.nh,self.hd).transpose(1,2)
        a = (q @ k.transpose(-2,-1)) / math.sqrt(self.hd)  # scaled dot-product
        # Zero out attention to future positions before softmax.
        a = a.masked_fill(self.mask[:,:,:T,:T]==0, float("-inf"))
        a = self.drop(F.softmax(a, dim=-1))
        return self.c_proj((a @ v).transpose(1,2).contiguous().view(B,T,C))
|
| 259 |
+
|
| 260 |
+
class FFN(nn.Module):
    """Position-wise feed-forward block: SparseLinear -> GELU -> SparseLinear -> Dropout."""

    def __init__(self, d, do, ffn_mult=4):
        super().__init__()
        hidden = ffn_mult * d
        self.c_fc = SparseLinear(d, hidden)
        self.c_proj = SparseLinear(hidden, d)
        self.drop = nn.Dropout(do)

    def forward(self, x):
        h = F.gelu(self.c_fc(x))
        return self.drop(self.c_proj(h))
|
| 268 |
+
|
| 269 |
+
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual."""

    def __init__(self, d, nh, bs, do, ffn_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = Attn(d, nh, bs, do)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = FFN(d, do, ffn_mult)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
|
| 277 |
+
|
| 278 |
+
class GPT(nn.Module):
    """Minimal GPT: token + positional embeddings, N transformer blocks, LM head."""

    def __init__(self, V, bs, nl, nh, d, do, ffn_mult=4):
        super().__init__()
        self.te = nn.Embedding(V, d); self.pe = nn.Embedding(bs, d)
        self.blocks = nn.Sequential(*[Block(d, nh, bs, do, ffn_mult) for _ in range(nl)])
        self.ln = nn.LayerNorm(d); self.head = nn.Linear(d, V)

    def forward(self, idx, tgt=None):
        """Return (logits, loss); loss is None when no targets are given."""
        B,T = idx.shape
        x = self.te(idx) + self.pe(torch.arange(T, device=idx.device))[None]
        lo = self.head(self.ln(self.blocks(x)))
        loss = F.cross_entropy(lo.view(-1, lo.size(-1)), tgt.view(-1)) if tgt is not None else None
        return lo, loss

    def nparams(self):
        # Total parameter count, active and inactive chunks alike.
        return sum(p.numel() for p in self.parameters())
|
| 291 |
+
|
| 292 |
+
def get_sparse_linears(m):
    """Return every SparseLinear submodule of *m*, in module-traversal order."""
    found = []
    for mod in m.modules():
        if isinstance(mod, SparseLinear):
            found.append(mod)
    return found
|
| 293 |
+
|
| 294 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 295 |
+
# DATA
|
| 296 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 297 |
+
|
| 298 |
+
class Corpus:
    """Tiny Shakespeare corpus: tiktoken GPT-2 BPE if available, else char-level.

    Cached as a class-level singleton keyed on block_size so repeated runs
    share the download and tokenization.
    """
    _inst = None  # cached singleton instance

    @classmethod
    def get(cls, bs, dev):
        # Reuse the cached corpus unless block_size changed.
        # NOTE(review): a device change alone does not invalidate the cache —
        # confirm callers never switch device within one process.
        if cls._inst is None or cls._inst.block_size != bs:
            cls._inst = cls(bs, dev)
        return cls._inst

    def __init__(self, block_size, device):
        self.block_size, self.device = block_size, device
        import urllib.request
        p = "input.txt"
        # Download Tiny Shakespeare once; reuse the local copy thereafter.
        if not os.path.exists(p):
            urllib.request.urlretrieve("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", p)
        text = open(p).read()
        if HAS_TIKTOKEN:
            enc = tiktoken.get_encoding("gpt2")
            tokens = enc.encode(text)
            self.vocab_size = enc.n_vocab
        else:
            # Fallback: character-level vocabulary built from the text itself.
            chars = sorted(set(text))
            stoi = {c:i for i,c in enumerate(chars)}
            tokens = [stoi[c] for c in text]
            self.vocab_size = len(chars)
        data = torch.tensor(tokens, dtype=torch.long)
        si = int(0.9 * len(data))  # 90/10 train/val split
        self.train_data, self.val_data = data[:si], data[si:]
        print(f"Corpus: V={self.vocab_size}, train={len(self.train_data):,}, val={len(self.val_data):,}")

    def get_batch(self, split, bs, gen=None):
        """Sample bs random next-token (x, y) windows; gen makes it deterministic."""
        d = self.train_data if split == "train" else self.val_data
        ix = torch.randint(len(d)-self.block_size-1, (bs,), generator=gen)
        x = torch.stack([d[i:i+self.block_size] for i in ix])
        y = torch.stack([d[i+1:i+self.block_size+1] for i in ix])
        return x.to(self.device), y.to(self.device)
|
| 334 |
+
|
| 335 |
+
def make_gen(s):
    """Return a CPU torch.Generator seeded with *s* (for reproducible batches)."""
    gen = torch.Generator(device="cpu")
    gen.manual_seed(s)
    return gen
|
| 337 |
+
|
| 338 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 339 |
+
# SCHEDULER (from v18, with KNN)
|
| 340 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 341 |
+
|
| 342 |
+
class ChunkScheduler:
    """Decides, each step, which output chunks of every SparseLinear are active.

    Policies:
      random — uniform random subset each step
      ema    — top-k by an EMA of per-chunk gradient magnitude
      knn    — EMA, plus similarity-weighted imputation for unobserved chunks
    """

    def __init__(self, model, policy, frac, cs, dev, beta=0.95, knn_k=3,
                 sim_hist=128, min_sim_hist=8):
        self.policy, self.frac, self.cs, self.dev = policy, frac, cs, dev
        self.beta, self.knn_k = beta, knn_k
        self.sim_hist, self.min_sim_hist = sim_hist, min_sim_hist
        self.linears = get_sparse_linears(model)
        # m2ids: module -> global chunk ids; m2loc: module -> local chunk ids.
        self.m2ids, self.m2loc = {}, {}
        off = 0
        for m in self.linears:
            m.chunk_size = cs
            nc = m.out_features // cs
            assert m.out_features % cs == 0
            self.m2ids[m] = torch.arange(off, off+nc, device=dev)
            self.m2loc[m] = torch.arange(nc, device=dev)
            off += nc
        self.nc = off  # total chunk count across all sparse linears
        self.ema = torch.zeros(self.nc, device=dev)  # EMA of grad magnitudes
        self.active = torch.zeros(self.nc, dtype=torch.bool, device=dev)
        self.mass_history = []  # warmup-time per-step magnitudes (feeds KNN sim)
        self.similarity = None  # chunk-chunk similarity matrix, built in warmup
        self.scores = torch.zeros(self.nc, device=dev)  # current selection scores

    def get_frac(self, step, wu, an):
        """Active-fraction schedule: 1.0 in warmup, cosine anneal, then frac."""
        if step < wu: return 1.0
        if an > 0 and step < wu + an:
            p = (step - wu) / an
            return self.frac + (1-self.frac) * 0.5 * (1 + math.cos(math.pi * p))
        return self.frac

    def choose(self, step, wu, an):
        """Select this step's active set and install it on the modules."""
        f = self.get_frac(step, wu, an)
        if f >= 0.999:
            self.active.fill_(True)
            self._install(); return
        k = max(1, int(f * self.nc))
        self.active.fill_(False)
        if self.policy == "random":
            idx = torch.randperm(self.nc, device=self.dev)[:k]
        elif self.policy == "ema":
            # Tiny noise breaks ties among equal (e.g. all-zero) scores.
            idx = torch.topk(self.ema + 1e-9*torch.rand_like(self.ema), k=k).indices
        elif self.policy == "knn":
            # Fall back to the raw EMA until KNN scores exist.
            base = self.scores if self.scores.sum() > 1e-12 else self.ema
            idx = torch.topk(base + 1e-9*torch.rand_like(base), k=k).indices
        else:
            raise ValueError(self.policy)
        self.active[idx] = True
        self._install()

    def _install(self):
        # Translate the global active mask into per-module local chunk lists.
        for m, gids in self.m2ids.items():
            m.active_chunks = self.m2loc[m][self.active[gids]]

    @torch.no_grad()
    def update(self, step, wu):
        """Fold this step's observed per-chunk grad magnitudes into EMA/scores."""
        cur = torch.zeros_like(self.ema)
        for m, ids in self.m2ids.items():
            if m.weight.grad is None: continue
            s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2))
            if m.bias is not None and m.bias.grad is not None:
                s += m.bias.grad.square().view(len(ids), self.cs).sum(1)
            cur[ids] = torch.sqrt(s + 1e-30)
        obs = self.active
        # First-ever observation seeds the EMA directly; later ones blend.
        new = obs & (self.ema == 0)
        old = obs & ~new
        self.ema[new] = cur[new]
        self.ema[old] = self.beta*self.ema[old] + (1-self.beta)*cur[old]
        # KNN similarity building during warmup
        if step < wu:
            self.mass_history.append(cur.clone())
            if len(self.mass_history) > self.sim_hist:
                self.mass_history = self.mass_history[-self.sim_hist:]
            if len(self.mass_history) >= self.min_sim_hist:
                self.similarity = self._build_sim()
        if self.policy == "knn":
            self.scores = self._knn_scores(self.active, cur)
        else:
            self.scores = self.ema.clone()
        return cur

    def _build_sim(self):
        """Nonnegative correlation between chunks from the warmup grad history."""
        H = torch.stack(self.mass_history)  # (steps, nc)
        H = (H - H.mean(0, keepdim=True)) / (H.std(0, keepdim=True) + 1e-6)
        S = torch.clamp((H.T @ H) / max(1, H.shape[0]-1), min=0)
        S.fill_diagonal_(0)
        # Restrict similarity to pairs within the same linear layer.
        ok = torch.zeros_like(S, dtype=torch.bool)
        for _, ids in self.m2ids.items():
            ok[ids[:,None], ids[None,:]] = True
        return torch.where(ok, S, torch.zeros_like(S))

    def _knn_scores(self, active_mask, cur):
        """Observed chunks get fresh magnitudes; unobserved ones are imputed
        from their knn_k most similar active neighbours (else keep EMA)."""
        if self.similarity is None: return self.ema.clone()
        sc = self.ema.clone()
        sc[active_mask] = cur[active_mask]
        aidx = active_mask.nonzero(as_tuple=False).flatten()
        iidx = (~active_mask).nonzero(as_tuple=False).flatten()
        if aidx.numel() == 0: return sc
        S = self.similarity
        for i in iidx.tolist():
            w = S[i, aidx]
            if w.sum() <= 1e-12: continue
            kk = min(self.knn_k, w.numel())
            top = torch.topk(w, k=kk)
            # Similarity-weighted average of the neighbours' fresh magnitudes.
            sc[i] = (top.values * cur[aidx[top.indices]]).sum() / (top.values.sum() + 1e-12)
        return sc

    @torch.no_grad()
    def oracle_scores(self):
        """Compute dense gradient magnitudes per chunk (requires dense grads already computed)."""
        sc = torch.zeros(self.nc, device=self.dev)
        for m, ids in self.m2ids.items():
            if m.weight.grad is None: continue
            s = m.weight.grad.square().view(len(ids), self.cs, -1).sum((1,2))
            if m.bias is not None and m.bias.grad is not None:
                s += m.bias.grad.square().view(len(ids), self.cs).sum(1)
            sc[ids] = torch.sqrt(s + 1e-30)
        return sc

    def measure_overlap(self, k):
        """Jaccard and recall of current active vs oracle top-k."""
        oracle = set(torch.topk(self.oracle_scores(), k=k).indices.tolist())
        pred = set(self.active.nonzero(as_tuple=True)[0].tolist())
        if not oracle or not pred: return 0., 0.
        inter = oracle & pred
        return len(inter)/len(oracle|pred), len(inter)/len(oracle)
|
| 467 |
+
|
| 468 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 469 |
+
# CHUNKED ADAM WITH PHANTOM/FROZEN MODES
|
| 470 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 471 |
+
|
| 472 |
+
class ChunkedAdam:
    """
    Adam with two modes for inactive chunks:
      phantom: standard - m,v decay even on zero grad (default, original behavior)
      frozen:  m,v state completely frozen for inactive chunks

    NOTE(review): uses fixed betas (0.9, 0.999), eps 1e-8, and NO bias
    correction — simpler than torch.optim.Adam and differs from it in the
    first few hundred steps; presumably intentional for the ablation.
    """

    def __init__(self, model, lr=3e-4, cs=64, momentum_mode="phantom"):
        self.model, self.lr, self.cs = model, lr, cs
        self.momentum_mode = momentum_mode  # "phantom" or "frozen"
        self.state = {}  # param -> {"m": ..., "v": ...}, created lazily
        self.p2m = {}    # param -> owning SparseLinear (for active_chunks lookup)
        for m in get_sparse_linears(model):
            if m.weight is not None: self.p2m[m.weight] = m
            if m.bias is not None: self.p2m[m.bias] = m

    def zero_grad(self):
        # Set grads to None rather than zeroing (cheaper; also lets the
        # scheduler's update() skip params that never received a grad).
        for p in self.model.parameters(): p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None: continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m, v = self.state[p]["m"], self.state[p]["v"]
            sm = self.p2m.get(p)
            ac = getattr(sm, 'active_chunks', None) if sm else None

            if ac is None:
                # Dense parameter (LN, embeddings, lm_head) - always full update
                m.mul_(0.9).add_(p.grad, alpha=0.1)
                v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                p.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr)
            else:
                if self.momentum_mode == "phantom":
                    # PHANTOM: update ALL chunks' moments, but only active get real gradients.
                    # Inactive chunks see grad=0, so m decays and v decays.
                    # This is the original behavior.
                    m.mul_(0.9).add_(p.grad, alpha=0.1)
                    v.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                    # But only update weights for active chunks
                    for c in ac.tolist():
                        s, e = c*self.cs, (c+1)*self.cs
                        p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr)
                elif self.momentum_mode == "frozen":
                    # FROZEN: only touch m,v,p for active chunks. Inactive state is untouched.
                    for c in ac.tolist():
                        s, e = c*self.cs, (c+1)*self.cs
                        g = p.grad[s:e]
                        m[s:e].mul_(0.9).add_(g, alpha=0.1)
                        v[s:e].mul_(0.999).addcmul_(g, g, value=0.001)
                        p.data[s:e].sub_(m[s:e] / (torch.sqrt(v[s:e]) + 1e-8), alpha=self.lr)
|
| 524 |
+
|
| 525 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 526 |
+
# EVALUATION
|
| 527 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 528 |
+
|
| 529 |
+
@torch.no_grad()
def evaluate(model, corpus, bs, n=20, seed=9999):
    """Mean validation loss over *n* fixed-seed batches.

    Returns (mean_loss, perplexity); the exponent is clamped at 20 so the
    perplexity never overflows. Restores train mode before returning.
    """
    model.eval()
    total = 0.0
    for i in range(n):
        xb, yb = corpus.get_batch("val", bs, make_gen(seed + i))
        _, loss = model(xb, yb)
        total += loss.item()
    model.train()
    mean = total / n
    return mean, math.exp(min(mean, 20))
|
| 539 |
+
|
| 540 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 541 |
+
# SINGLE TRAINING RUN
|
| 542 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 543 |
+
|
| 544 |
+
def run(policy, bwd_mode, steps, bs, block_size, nl, nh, d, cs,
        active_frac, wu, an, lr, device, seed, backend="triton",
        momentum_mode="phantom", ffn_mult=4,
        measure_oracle=False, oracle_interval=50):
    """Run one training config. Returns dict of results.

    policy: "dense" or a ChunkScheduler policy ("random"/"ema"/"knn").
    bwd_mode: "sparse_dX" sparsifies the input gradient too; otherwise full.
    Returns val_loss/val_ppl/wall_time/ms_per_step/n_params/
    train_loss_final/overlaps.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    random.seed(seed)

    corpus = Corpus.get(block_size, device)
    model = GPT(corpus.vocab_size, block_size, nl, nh, d, 0.1, ffn_mult).to(device)
    for m in get_sparse_linears(model):
        m.chunk_size = cs
        m.backend = backend

    is_dense = (policy == "dense")
    sched = None if is_dense else ChunkScheduler(model, policy, active_frac, cs, device)
    opt = ChunkedAdam(model, lr=lr, cs=cs, momentum_mode=momentum_mode)

    np_ = model.nparams()
    overlaps = []  # (step, jaccard, recall) triples from oracle measurements

    # Sync so wall-clock timing covers only the training loop.
    torch.cuda.synchronize() if device == "cuda" else None
    t0 = time.perf_counter()

    for step in range(steps):
        # Deterministic batch per step (same data order across configs).
        x, y = corpus.get_batch("train", bs, make_gen(step))

        if is_dense:
            for m in get_sparse_linears(model):
                m.sparse_enabled = False; m.active_chunks = None
        else:
            sched.choose(step, wu, an)
            for m in get_sparse_linears(model):
                m.sparse_enabled = True
                m.sparse_dx = (bwd_mode == "sparse_dX")

        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        if sched:
            sched.update(step, wu)

        # Oracle overlap measurement
        if measure_oracle and step % oracle_interval == 0 and step >= wu + an:
            # Save the sparse grads, recompute dense grads on the same batch,
            # measure overlap against the oracle top-k, then restore.
            saved = {p: p.grad.clone() for p in model.parameters() if p.grad is not None}
            for m in get_sparse_linears(model): m.sparse_enabled = False
            for p in model.parameters(): p.grad = None
            _, lo = model(x, y); lo.backward()
            k = max(1, int(active_frac * sched.nc))
            j, r = sched.measure_overlap(k)
            overlaps.append((step, j, r))
            for p in model.parameters():
                if p in saved: p.grad = saved[p]
            for m in get_sparse_linears(model): m.sparse_enabled = True

        opt.step()

        if step % 200 == 0:
            print(f" step {step}/{steps} loss={loss.item():.4f}")

    torch.cuda.synchronize() if device == "cuda" else None
    wall = time.perf_counter() - t0

    # Evaluate densely (sparsity off) for a fair comparison across configs.
    for m in get_sparse_linears(model): m.sparse_enabled = False
    vl, vp = evaluate(model, corpus, bs, n=30)

    del model; torch.cuda.empty_cache() if device == "cuda" else None

    return {
        "val_loss": vl, "val_ppl": vp, "wall_time": wall,
        "ms_per_step": 1000*wall/steps, "n_params": np_,
        "train_loss_final": loss.item(), "overlaps": overlaps,
    }
|
| 619 |
+
|
| 620 |
+
def run_seeds(cfg, seeds):
    """Run the same configuration once per seed and aggregate statistics.

    Mutates cfg["seed"] in place before each run. Returns a dict with the
    mean and sample standard deviation of val_loss, the per-seed result
    dicts, and the mean ms-per-step across runs.
    """
    outcomes = []
    for seed in seeds:
        cfg["seed"] = seed
        outcomes.append(run(**cfg))

    losses = [o["val_loss"] for o in outcomes]
    n = len(losses)
    mean_loss = sum(losses) / n
    # Sample std-dev; max(1, n-1) keeps a single-seed run from dividing by zero.
    variance = sum((v - mean_loss) ** 2 for v in losses) / max(1, n - 1)
    mean_ms = sum(o["ms_per_step"] for o in outcomes) / n

    return {
        "mean_loss": mean_loss,
        "std_loss": variance ** 0.5,
        "results": outcomes,
        "mean_ms": mean_ms,
    }
|
| 630 |
+
|
| 631 |
+
# ───────────────────────────────────────────────────────────────
# EXPERIMENT 1: PHANTOM MOMENTUM ABLATION
# ───────────────────────────────────────────────────────────────
|
| 634 |
+
|
| 635 |
+
def exp_phantom_momentum(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend):
    """Ablation crossing chunk-selection policies with momentum handling modes."""
    print("\n" + "=" * 80)
    print("EXPERIMENT 1: Phantom Momentum Ablation")
    print("=" * 80)

    shared = dict(
        bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size,
        nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an,
        lr=lr, device=device, backend=backend,
    )

    # (label, chunk-selection policy, momentum mode)
    configs = [
        ("dense", "dense", "phantom"),
        ("ema+phantom", "ema", "phantom"),
        ("ema+frozen", "ema", "frozen"),
        ("knn+phantom", "knn", "phantom"),
        ("knn+frozen", "knn", "frozen"),
        ("random+phantom", "random", "phantom"),
        ("random+frozen", "random", "frozen"),
    ]

    results = {}
    for label, policy, momentum in configs:
        print(f"\n--- {label} ---")
        results[label] = run_seeds({**shared, "policy": policy, "momentum_mode": momentum}, seeds)

    # Summary table, in the same order the configs were run.
    print(f"\n{'Method':<22} | {'Val Loss':>18} | {'ms/step':>10}")
    print("-" * 55)
    for label, _, _ in configs:
        stats = results[label]
        print(f"{label:<22} | {stats['mean_loss']:.4f} ± {stats['std_loss']:.4f} | {stats['mean_ms']:>9.1f}")

    return results
|
| 667 |
+
|
| 668 |
+
# ───────────────────────────────────────────────────────────────
# EXPERIMENT 2: COMPUTE-MATCHED BASELINES
# ───────────────────────────────────────────────────────────────
|
| 671 |
+
|
| 672 |
+
def exp_compute_matched(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend):
    """Compare sparse training against dense baselines matched on steps, compute, or capacity.

    Runs four configurations:
      1. sparse EMA reference,
      2. dense at the same step count,
      3. dense at a FLOP-matched (reduced) step count,
      4. a natively smaller dense model (capacity-matched FFN multiplier).

    Returns a dict of {config_name: run_seeds(...) aggregate}.
    """
    print("\n" + "=" * 80)
    print("EXPERIMENT 2: Compute-Matched Baselines")
    print("=" * 80)

    base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size,
                nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an,
                lr=lr, device=device, backend=backend, momentum_mode="phantom")

    # 1. Sparse reference
    print("\n--- Sparse (EMA, reference) ---")
    sparse_r = run_seeds({**base, "policy": "ema"}, seeds)

    # 2. Dense at same steps
    print("\n--- Dense (same steps) ---")
    dense_same = run_seeds({**base, "policy": "dense"}, seeds)

    # 3. Dense at compute-matched steps.
    # Sparse does roughly (fwd dense + dX dense + dW at af) of dense's three
    # matmul passes, so dense gets proportionally fewer steps.
    ratio = (1.0 + 1.0 + af) / 3.0
    matched_steps = int(steps * ratio)
    print(f"\n--- Dense (compute-matched, {matched_steps} steps) ---")
    dense_matched = run_seeds({**base, "policy": "dense", "steps": matched_steps}, seeds)

    # 4. Natively smaller dense model: FFN multiplier ≈ 4 * af, floored at 1
    # (e.g. af=0.1 → round(0.4)=0 → clamped to 1).
    small_ffn_mult = max(1, round(4 * af))
    print(f"\n--- Small dense (ffn_mult={small_ffn_mult}, capacity-matched) ---")
    dense_small = run_seeds({**base, "policy": "dense", "ffn_mult": small_ffn_mult}, seeds)

    results = {
        "sparse_ema": sparse_r,
        "dense_same_steps": dense_same,
        f"dense_matched_{matched_steps}steps": dense_matched,
        f"dense_small_ffn{small_ffn_mult}": dense_small,
    }
    # Track the step count of each configuration explicitly. The previous
    # r["results"][0].get("steps", ...) lookup was dead code: run() results
    # never carry a "steps" key, so the fallback always fired.
    steps_by_name = {
        "sparse_ema": steps,
        "dense_same_steps": steps,
        f"dense_matched_{matched_steps}steps": matched_steps,
        f"dense_small_ffn{small_ffn_mult}": steps,
    }

    print(f"\n{'Method':<35} | {'Steps':>6} | {'Params':>8} | {'Val Loss':>18} | {'ms/step':>10}")
    print("-" * 90)
    for name, r in results.items():
        np_ = r["results"][0]["n_params"]
        st = steps_by_name[name]
        print(f"{name:<35} | {st:>6} | {np_/1e6:>7.1f}M | {r['mean_loss']:.4f} ± {r['std_loss']:.4f} | {r['mean_ms']:>9.1f}")

    return results
|
| 718 |
+
|
| 719 |
+
# ───────────────────────────────────────────────────────────────
# EXPERIMENT 3: PREDICTOR ACCURACY (EMA vs KNN vs Oracle)
# ───────────────────────────────────────────────────────────────
|
| 722 |
+
|
| 723 |
+
def exp_predictor_accuracy(device, steps, seeds, d, nl, nh, bs, block_size, cs, af, wu, an, lr, backend):
    """Compare chunk-selection predictors (EMA / KNN / random) via oracle overlap."""
    print("\n" + "=" * 80)
    print("EXPERIMENT 3: Predictor Accuracy (EMA vs KNN vs Oracle)")
    print("=" * 80)

    base = dict(bwd_mode="full_dX", steps=steps, bs=bs, block_size=block_size,
                nl=nl, nh=nh, d=d, cs=cs, active_frac=af, wu=wu, an=an,
                lr=lr, device=device, backend=backend, momentum_mode="phantom",
                measure_oracle=True, oracle_interval=25)

    policies = ["ema", "knn", "random"]

    results = {}
    for policy in policies:
        print(f"\n--- {policy} ---")
        results[policy] = run_seeds({**base, "policy": policy}, seeds)

    # Aggregate the per-step (Jaccard, recall) overlap measurements across seeds.
    for policy in policies:
        print(f"\n{policy.upper()} predictor overlap:")
        print(f"  {'Step':>6} | {'Jaccard':>10} | {'Recall':>10}")
        per_step = defaultdict(lambda: {"j": [], "r": []})
        for seed_result in results[policy]["results"]:
            for step, jac, rec in seed_result["overlaps"]:
                per_step[step]["j"].append(jac)
                per_step[step]["r"].append(rec)
        for step in sorted(per_step):
            jacs = per_step[step]["j"]
            recs = per_step[step]["r"]
            print(f"  {step:>6} | {sum(jacs)/len(jacs):>10.4f} | {sum(recs)/len(recs):>10.4f}")

    # Final summary table.
    print(f"\n{'Policy':<10} | {'Val Loss':>18} | {'ms/step':>10}")
    print("-" * 45)
    for policy in policies:
        stats = results[policy]
        print(f"{policy:<10} | {stats['mean_loss']:.4f} ± {stats['std_loss']:.4f} | {stats['mean_ms']:>9.1f}")

    return results
|
| 758 |
+
|
| 759 |
+
# ───────────────────────────────────────────────────────────────
# MAIN
# ───────────────────────────────────────────────────────────────
|
| 762 |
+
|
| 763 |
+
# Registry of runnable experiments; `--experiment` choices are derived from it.
ALL_EXPS = dict(
    phantom_momentum=exp_phantom_momentum,
    compute_matched=exp_compute_matched,
    predictor_accuracy=exp_predictor_accuracy,
)
|
| 768 |
+
|
| 769 |
+
def main():
    """Parse CLI arguments, run the selected ablation experiment(s), and save JSON results."""
    p = argparse.ArgumentParser()
    p.add_argument("--experiment", default="all", choices=list(ALL_EXPS) + ["all"])
    p.add_argument("--device", default="cuda")
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--seeds", default="42,123,456", help="comma-separated integer seeds")
    p.add_argument("--n_embd", type=int, default=1024)
    p.add_argument("--n_layer", type=int, default=4)
    p.add_argument("--n_head", type=int, default=8)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--block_size", type=int, default=256)
    p.add_argument("--chunk_size", type=int, default=64)
    p.add_argument("--active_fraction", type=float, default=0.10)
    p.add_argument("--warmup_steps", type=int, default=50)
    p.add_argument("--anneal_steps", type=int, default=200)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--backend", default="triton", choices=["triton", "torch"])
    p.add_argument("--output_dir", default="results")
    args = p.parse_args()

    seeds = [int(s) for s in args.seeds.split(",")]
    os.makedirs(args.output_dir, exist_ok=True)

    if args.device == "cuda":
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name()} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
        else:
            # Previously a CUDA-less machine would crash deep inside training;
            # degrade gracefully instead.
            print("WARNING: --device cuda requested but CUDA is unavailable; falling back to cpu.")
            args.device = "cpu"
    print(f"Config: d={args.n_embd} nl={args.n_layer} nh={args.n_head} steps={args.steps} seeds={seeds}")
    print(f"  cs={args.chunk_size} af={args.active_fraction} backend={args.backend}")

    shared = dict(device=args.device, steps=args.steps, seeds=seeds,
                  d=args.n_embd, nl=args.n_layer, nh=args.n_head,
                  bs=args.batch_size, block_size=args.block_size,
                  cs=args.chunk_size, af=args.active_fraction,
                  wu=args.warmup_steps, an=args.anneal_steps,
                  lr=args.lr, backend=args.backend)

    exps = ALL_EXPS if args.experiment == "all" else {args.experiment: ALL_EXPS[args.experiment]}
    t0 = time.time()

    # Hoisted out of the loop: the original re-defined this closure every iteration.
    def ser(o):
        """Recursively stringify dict keys so nested results are JSON-serializable."""
        if isinstance(o, dict):
            return {str(k): ser(v) for k, v in o.items()}
        if isinstance(o, list):
            return [ser(x) for x in o]
        return o

    for name, fn in exps.items():
        print(f"\n{'#'*80}\n# {name} ({(time.time()-t0)/60:.1f}m elapsed)\n{'#'*80}")
        sys.stdout.flush()  # keep progress visible when stdout is piped to a log file
        result = fn(**shared)

        with open(os.path.join(args.output_dir, f"{name}.json"), "w") as f:
            # default=str catches any remaining non-JSON values (e.g. tensors).
            json.dump(ser(result), f, indent=2, default=str)
        print(f"✓ {name} saved to {args.output_dir}/{name}.json")

    print(f"\n{'='*80}\nALL COMPLETE in {(time.time()-t0)/60:.1f} minutes\n{'='*80}")
|
| 822 |
+
|
| 823 |
+
# Script entry point: run the ablation suite only when executed directly.
if __name__ == "__main__":
    main()
|