theapemachine committed on
Commit
bc1b8eb
·
1 Parent(s): 514b330

Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.

Browse files
Files changed (34) hide show
  1. e2e_bench.py → experiments/e2e_bench.py +0 -0
  2. e2e_full.py → experiments/e2e_full.py +0 -0
  3. experiments/pyproject.toml +19 -0
  4. experiments/sparse_linear_v10_metal/README.md +62 -0
  5. experiments/sparse_linear_v10_metal/setup.py +45 -0
  6. experiments/sparse_linear_v10_metal/sparse_linear.metal +65 -0
  7. experiments/sparse_linear_v10_metal/sparse_linear_ops.mm +158 -0
  8. experiments/sparse_linear_v10_metal/sparse_transformer_v10.py +1112 -0
  9. experiments/sparse_linear_v11_gather_vs_metal/README.md +62 -0
  10. experiments/sparse_linear_v11_gather_vs_metal/input.txt +0 -0
  11. experiments/sparse_linear_v11_gather_vs_metal/setup.py +15 -0
  12. experiments/sparse_linear_v11_gather_vs_metal/sparse_linear.metal +62 -0
  13. experiments/sparse_linear_v11_gather_vs_metal/sparse_linear_ops.mm +168 -0
  14. experiments/sparse_linear_v11_gather_vs_metal/sparse_transformer_v11.py +406 -0
  15. experiments/sparse_linear_v11_gather_vs_metal/sparse_transformer_v13.py +419 -0
  16. experiments/sparse_linear_v11_gather_vs_metal/tiny.py +441 -0
  17. experiments/sparse_transformer_v15_inactive_prediction.py +729 -0
  18. experiments/sparse_transformer_v16_sensor_scheduler.py +677 -0
  19. experiments/sparse_transformer_v17_radar_scheduler.py +725 -0
  20. experiments/sparse_transformer_v6.py +596 -0
  21. experiments/sparse_transformer_v7.py +780 -0
  22. experiments/sparse_transformer_v8.py +943 -0
  23. experiments/sparse_transformer_v9.py +1042 -0
  24. experiments/surprise_topk_gradient_prototype-v2.py +418 -0
  25. experiments/surprise_topk_gradient_prototype-v3.py +487 -0
  26. experiments/surprise_topk_gradient_prototype-v4.py +571 -0
  27. experiments/surprise_topk_gradient_prototype-v5.py +563 -0
  28. experiments/surprise_topk_gradient_prototype.py +426 -0
  29. triton_sparse.py → experiments/triton_sparse.py +0 -0
  30. triton_v2.py → experiments/triton_v2.py +0 -0
  31. experiments/uv.lock +0 -0
  32. paper/main.tex +307 -0
  33. sparse_transformer_v18_fast_knn.py +459 -0
  34. sparse_transformer_v18_fast_knn_triton.py +1044 -0
e2e_bench.py → experiments/e2e_bench.py RENAMED
File without changes
e2e_full.py → experiments/e2e_full.py RENAMED
File without changes
experiments/pyproject.toml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Packaging metadata for the surprise Top-K gradient experiment prototype.
[build-system]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"

[project]
name = "surprise-topk-gradient"
version = "0.1.0"
description = "Prototype for surprise Top-K gradient training experiments"
requires-python = ">=3.10"
dependencies = [
    "numpy>=1.24",
    "torch>=2.0",
]

[project.scripts]
surprise-topk-gradient = "surprise_topk_gradient_prototype:main"

[tool.setuptools]
py-modules = ["surprise_topk_gradient_prototype"]
experiments/sparse_linear_v10_metal/README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Sparse Transformer v10: Metal-backed active-row Linear backward benchmark
2
+
3
+ This bundle is the first empirical optimization test rather than only a correctness prototype.
4
+ It compares dense Transformer training against sparse active-row backward variants and reports
5
+ validation loss plus wall-clock metrics (`step_ms`, `tokens_per_s`).
6
+
7
+ ## Files
8
+
9
+ - `sparse_transformer_v10.py` — training + benchmark harness.
10
+ - `sparse_linear.metal` — Metal kernels for active-row `dW/db` and sparse `dX`.
11
+ - `sparse_linear_ops.mm` — PyTorch/MPS extension glue.
12
+ - `setup.py` — builds the extension and compiles the `.metallib`.
13
+
14
+ ## Install the Metal extension
15
+
16
+ From this directory on macOS with PyTorch MPS available:
17
+
18
+ ```bash
19
+ python3 -m pip install -e .
20
+ ```
21
+
22
+ This uses `xcrun metal` and `xcrun metallib`, so Xcode command-line tools are required.
23
+
24
+ ## Correctness/sanity run with PyTorch fallback
25
+
26
+ ```bash
27
+ python3 sparse_transformer_v10.py \
28
+ --device mps \
29
+ --steps 2000 \
30
+ --active_fractions 0.05 0.02 \
31
+ --warmup_steps_list 5 \
32
+ --policies predicted_magnitude random \
33
+ --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
34
+ --audit_every 0 \
35
+ --kernel_backend torch \
36
+ --benchmark_sync
37
+ ```
38
+
39
+ ## Metal benchmark run
40
+
41
+ ```bash
42
+ python3 sparse_transformer_v10.py \
43
+ --device mps \
44
+ --steps 2000 \
45
+ --active_fractions 0.05 0.02 \
46
+ --warmup_steps_list 5 \
47
+ --policies predicted_magnitude random \
48
+ --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
49
+ --audit_every 0 \
50
+ --kernel_backend metal \
51
+ --benchmark_sync
52
+ ```
53
+
54
+ ## Empirical pass/fail criteria
55
+
56
+ A genuine optimization would show:
57
+
58
+ 1. `predicted_magnitude` validation loss close to the PyTorch fallback v9/v10 result.
59
+ 2. `predicted_magnitude` much better than `random` at the same active fraction.
60
+ 3. `--kernel_backend metal` has lower `step_ms` or higher `tokens_per_s` than `--kernel_backend torch` and ideally dense baseline.
61
+
62
+ The first Metal kernels are intentionally simple and fp32-only. They are meant to prove or disprove the acceleration path before investing in tiled half-precision kernels.
experiments/sparse_linear_v10_metal/setup.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Build script for the sparse_linear_metal extension.

Compiles the Metal shaders to a .metallib via xcrun, then builds the
Objective-C++ PyTorch extension that loads that library at runtime.
"""
import shutil
import subprocess
from pathlib import Path

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

ROOT = Path(__file__).parent.resolve()
# PyTorch dylibs (libc10, libtorch, ...) are @rpath-linked; embed torch/lib so import works from any cwd.
_TORCH_LIB = Path(torch.__file__).resolve().parent / "lib"
METAL_SRC = ROOT / "sparse_linear.metal"
AIR = ROOT / "sparse_linear.air"
METALLIB = ROOT / "sparse_linear_ops.metallib"


class MetalBuildExt(BuildExtension):
    """BuildExtension that compiles the Metal kernels before the C++ sources."""

    def run(self):
        if shutil.which("xcrun") is None:
            raise RuntimeError("xcrun not found. Install Xcode command line tools.")
        # .metal -> .air -> .metallib (requires Xcode command-line tools).
        subprocess.check_call(["xcrun", "-sdk", "macosx", "metal", "-c", str(METAL_SRC), "-o", str(AIR)])
        subprocess.check_call(["xcrun", "-sdk", "macosx", "metallib", str(AIR), "-o", str(METALLIB)])
        super().run()
        # Ship the metallib next to the built extension .so; the extension
        # locates it relative to its own path at import time.
        build_lib = Path(self.build_lib)
        for so in build_lib.rglob("sparse_linear_metal*.so"):
            shutil.copy2(METALLIB, so.parent / METALLIB.name)


setup(
    name="sparse_linear_metal",
    version="0.1.0",
    ext_modules=[
        CppExtension(
            name="sparse_linear_metal",
            sources=["sparse_linear_ops.mm"],
            extra_compile_args={"cxx": ["-std=c++17", "-ObjC++", "-fobjc-arc"]},
            extra_link_args=[
                "-framework",
                "Metal",
                "-framework",
                "Foundation",
                f"-Wl,-rpath,{_TORCH_LIB}",
            ],
        )
    ],
    cmdclass={"build_ext": MetalBuildExt},
)
experiments/sparse_linear_v10_metal/sparse_linear.metal ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#include <metal_stdlib>
using namespace metal;

// Launch parameters; must mirror the host-side struct in sparse_linear_ops.mm
// byte-for-byte.
struct SparseLinearParams {
    uint32_t N;      // flattened batch rows (batch * seq)
    uint32_t In;     // in_features
    uint32_t Out;    // out_features
    uint32_t dummy;  // alignment padding
};

// dW/db restricted to active output rows. grad_w and grad_b must be zeroed by
// the caller; inactive rows are simply never written.
kernel void sparse_linear_grad_w_float(
    device const float* x [[buffer(0)]],           // [N, In]
    device const float* gy [[buffer(1)]],          // [N, Out]
    device const bool* active_mask [[buffer(2)]],  // [Out] boolean mask
    device float* grad_w [[buffer(3)]],            // [Out, In], zeroed by caller
    device float* grad_b [[buffer(4)]],            // [Out], zeroed by caller
    constant SparseLinearParams& p [[buffer(5)]],
    uint2 tid [[thread_position_in_grid]])
{
    const uint col = tid.x;
    const uint row = tid.y;

    // Bounds check: the grid may be padded past In/Out.
    if (row >= p.Out || col >= p.In) return;

    // The magic: threads on inactive rows exit instantly, so no CPU-side
    // sync or index compaction is required.
    if (!active_mask[row]) return;

    float acc = 0.0f;
    for (uint n = 0; n < p.N; ++n) {
        acc += gy[n * p.Out + row] * x[n * p.In + col];
    }
    grad_w[row * p.In + col] = acc;

    // One thread per row (col == 0) also reduces the bias gradient.
    // (Could be optimized further, but fine for now.)
    if (col == 0) {
        float bias_acc = 0.0f;
        for (uint n = 0; n < p.N; ++n) {
            bias_acc += gy[n * p.Out + row];
        }
        grad_b[row] = bias_acc;
    }
}

// Sparse dX: grad_x = gy[:, active] @ W[active, :], skipping inactive rows.
kernel void sparse_linear_grad_x_float(
    device const float* gy [[buffer(0)]],          // [N, Out]
    device const float* weight [[buffer(1)]],      // [Out, In]
    device const bool* active_mask [[buffer(2)]],  // [Out] boolean mask
    device float* grad_x [[buffer(3)]],            // [N, In], zeroed by caller
    constant SparseLinearParams& p [[buffer(4)]],
    uint2 tid [[thread_position_in_grid]])
{
    const uint col = tid.x;
    const uint n = tid.y;
    if (n >= p.N || col >= p.In) return;

    float acc = 0.0f;
    for (uint row = 0; row < p.Out; ++row) {
        if (active_mask[row]) {
            acc += gy[n * p.Out + row] * weight[row * p.In + col];
        }
    }
    grad_x[n * p.In + col] = acc;
}
experiments/sparse_linear_v10_metal/sparse_linear_ops.mm ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/mps/MPSStream.h>
4
+ #include <ATen/mps/MPSDevice.h>
5
+ #include <c10/util/Exception.h>
6
+ #include <filesystem>
7
+ #include <dlfcn.h>
8
+ #import <Metal/Metal.h>
9
+ #import <Foundation/Foundation.h>
10
+ #include <mutex>
11
+
12
+ namespace fs = std::filesystem;
13
+
14
+ namespace {
15
+ struct SparseLinearParams {
16
+ uint32_t N;
17
+ uint32_t In;
18
+ uint32_t Out;
19
+ uint32_t dummy;
20
+ };
21
+
22
+ static id<MTLLibrary> g_lib = nil;
23
+ static id<MTLComputePipelineState> g_pipeline_grad_w = nil;
24
+ static id<MTLComputePipelineState> g_pipeline_grad_x = nil;
25
+ static std::mutex g_mutex;
26
+
27
+ static std::string metallib_path_for_this_module() {
28
+ Dl_info info;
29
+ if (dladdr((void*)&metallib_path_for_this_module, &info) == 0 || info.dli_fname == nullptr) return std::string();
30
+ fs::path so_path(info.dli_fname);
31
+ return (so_path.parent_path() / "sparse_linear_ops.metallib").string();
32
+ }
33
+
34
+ static void ensure_library_locked(id<MTLDevice> device) {
35
+ if (g_lib != nil) return;
36
+ std::string path = metallib_path_for_this_module();
37
+ TORCH_CHECK(!path.empty(), "sparse_linear_ops: failed to locate extension path via dladdr");
38
+ NSString* ns_path = [NSString stringWithUTF8String:path.c_str()];
39
+ NSURL* url = [NSURL fileURLWithPath:ns_path];
40
+ NSError* err = nil;
41
+ g_lib = [device newLibraryWithURL:url error:&err];
42
+ if (g_lib == nil) {
43
+ const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
44
+ TORCH_CHECK(false, "sparse_linear_ops: failed to load metallib at ", path, ": ", msg);
45
+ }
46
+ }
47
+
48
+ static id<MTLComputePipelineState> ensure_pipeline(id<MTLDevice> device, id<MTLComputePipelineState>* pipeline, const char* fn_name) {
49
+ std::lock_guard<std::mutex> lock(g_mutex);
50
+ ensure_library_locked(device);
51
+ if (*pipeline != nil) return *pipeline;
52
+ NSString* ns_fn = [NSString stringWithUTF8String:fn_name];
53
+ id<MTLFunction> fn =[g_lib newFunctionWithName:ns_fn];
54
+ TORCH_CHECK(fn != nil, "sparse_linear_ops: function `", fn_name, "` not found in metallib");
55
+ NSError* err = nil;
56
+ *pipeline = [device newComputePipelineStateWithFunction:fn error:&err];
57
+ if (*pipeline == nil) {
58
+ const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
59
+ TORCH_CHECK(false, "sparse_linear_ops: failed to create pipeline for ", fn_name, ": ", msg);
60
+ }
61
+ return *pipeline;
62
+ }
63
+
64
+ static inline id<MTLBuffer> storage_as_mtlbuffer(const at::Tensor& t) {
65
+ void* ctx = t.storage().data_ptr().get_context();
66
+ TORCH_CHECK(ctx != nullptr, "sparse_linear_ops: expected MPS tensor storage with MTLBuffer context");
67
+ return (__bridge id<MTLBuffer>)ctx;
68
+ }
69
+
70
+ static inline NSUInteger storage_offset_bytes(const at::Tensor& t) {
71
+ return (NSUInteger)(t.storage_offset() * (int64_t)t.element_size());
72
+ }
73
+
74
+ static void check_mps_float_contig(const at::Tensor& t, const char* name) {
75
+ TORCH_CHECK(t.device().is_mps(), name, " must be on MPS");
76
+ TORCH_CHECK(t.dtype() == at::kFloat, name, " must be float32 for v12 kernel");
77
+ TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
78
+ }
79
+ } // namespace
80
+
81
+ std::vector<at::Tensor> sparse_linear_grad_wb(at::Tensor x2d, at::Tensor gy2d, at::Tensor active_mask) {
82
+ check_mps_float_contig(x2d, "x2d");
83
+ check_mps_float_contig(gy2d, "gy2d");
84
+ TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
85
+ TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
86
+ TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");
87
+
88
+ int64_t N = x2d.size(0);
89
+ int64_t In = x2d.size(1);
90
+ int64_t Out = active_mask.size(0);
91
+ TORCH_CHECK(gy2d.size(1) == Out, "gy2d width must equal active_mask size");
92
+
93
+ auto grad_w = at::zeros({Out, In}, x2d.options());
94
+ auto grad_b = at::zeros({Out}, x2d.options());
95
+
96
+ id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
97
+ id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_w, "sparse_linear_grad_w_float");
98
+ at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
99
+
100
+ id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();[encoder setComputePipelineState:pipeline];
101
+
102
+ auto set_tensor = [&](const at::Tensor& t, int idx) {[encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
103
+ };
104
+ set_tensor(x2d, 0);
105
+ set_tensor(gy2d, 1);
106
+ set_tensor(active_mask, 2);
107
+ set_tensor(grad_w, 3);
108
+ set_tensor(grad_b, 4);
109
+
110
+ SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
111
+ [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:5];
112
+
113
+ MTLSize tg = MTLSizeMake(16, 16, 1);
114
+ MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((Out + 15) / 16), 1);[encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
115
+
116
+ return {grad_w, grad_b};
117
+ }
118
+
119
+ at::Tensor sparse_linear_grad_x(at::Tensor gy2d, at::Tensor weight, at::Tensor active_mask) {
120
+ check_mps_float_contig(gy2d, "gy2d");
121
+ check_mps_float_contig(weight, "weight");
122
+ TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
123
+ TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
124
+ TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");
125
+
126
+ int64_t N = gy2d.size(0);
127
+ int64_t Out = gy2d.size(1);
128
+ int64_t In = weight.size(1);
129
+ TORCH_CHECK(weight.size(0) == Out, "weight out_features must match gy2d width");
130
+
131
+ auto grad_x = at::zeros({N, In}, gy2d.options());
132
+
133
+ id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
134
+ id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_x, "sparse_linear_grad_x_float");
135
+ at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
136
+
137
+ id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();[encoder setComputePipelineState:pipeline];
138
+
139
+ auto set_tensor = [&](const at::Tensor& t, int idx) {[encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
140
+ };
141
+ set_tensor(gy2d, 0);
142
+ set_tensor(weight, 1);
143
+ set_tensor(active_mask, 2);
144
+ set_tensor(grad_x, 3);
145
+
146
+ SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};[encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:4];
147
+
148
+ MTLSize tg = MTLSizeMake(16, 16, 1);
149
+ MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((N + 15) / 16), 1);
150
+ [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
151
+
152
+ return grad_x;
153
+ }
154
+
155
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
156
+ m.def("sparse_linear_grad_wb", &sparse_linear_grad_wb, "Sparse active-row Linear dW/db (Metal/MPS, fp32)");
157
+ m.def("sparse_linear_grad_x", &sparse_linear_grad_x, "Sparse active-row Linear dX (Metal/MPS, fp32)");
158
+ }
experiments/sparse_linear_v10_metal/sparse_transformer_v10.py ADDED
@@ -0,0 +1,1112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v10: empirical sparse-backward optimization benchmark with optional Metal kernels.
3
+
4
+ v8 proved that the row-sparse mask can be moved into a custom Linear backward.
5
+ v10 keeps the no-audit training path and adds wall-clock benchmarking plus an optional Metal active-row Linear backward kernel.
6
+
7
+ Default behavior
8
+ ----------------
9
+ 1. Run a short dense warmup, usually 5 steps.
10
+ 2. Initialize the EMA row-importance predictor from those dense warmup gradients.
11
+ 3. After warmup, choose active rows from the predictor.
12
+ 4. Train using sparse backward.
13
+ 5. Update EMA statistics only from rows that were actually active/observed.
14
+ 6. Do not compute dense gradients unless --audit_every > 0.
15
+
16
+ Audit behavior
17
+ --------------
18
+ --audit_every 0
19
+ No dense audit after warmup. Cosine/Jaccard/top20 are unavailable and show as nan.
20
+
21
+ --audit_every N
22
+ Every N steps, run an extra dense backward pass on the same batch only to
23
+ measure cosine/top20/Jaccard. The audit is NOT used to update the selector,
24
+ except for oracle_current, which is explicitly an upper-bound control.
25
+
26
+ This is still not a wall-clock benchmark on vanilla PyTorch/MPS/CPU. The custom
27
+ backward uses indexing and ordinary PyTorch matmuls. The goal is to verify that
28
+ the method survives without dense information after warmup.
29
+
30
+ Examples
31
+ --------
32
+ No-audit practical run:
33
+ python3 sparse_transformer_v10.py \
34
+ --device mps \
35
+ --steps 2000 \
36
+ --active_fractions 0.05 0.02 \
37
+ --warmup_steps_list 5 \
38
+ --policies predicted_magnitude random \
39
+ --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
40
+ --audit_every 0
41
+
42
+ Occasional audit for measurement only:
43
+ python3 sparse_transformer_v10.py \
44
+ --steps 2000 \
45
+ --active_fractions 0.05 0.02 \
46
+ --warmup_steps_list 5 \
47
+ --policies predicted_magnitude random \
48
+ --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
49
+ --audit_every 100
50
+ """
51
+
52
+
53
+ from __future__ import annotations
54
+
55
+ import argparse
56
+ import math
57
+ import random
58
+ import time
59
+ from typing import Dict, List, Literal, Optional, Tuple
60
+
61
+ import torch
62
+
63
+ torch.set_num_threads(1)
64
+ import torch.nn as nn
65
+ import torch.nn.functional as F
66
+
67
+ try:
68
+ import sparse_linear_metal # built by setup.py in this folder
69
+ except Exception:
70
+ sparse_linear_metal = None
71
+
72
+ USE_METAL_KERNEL = False
73
+
74
def sync_device(device: str) -> None:
    """Wait for all queued work on ``device`` to finish.

    Safe to call unconditionally around timing sections: a no-op on CPU
    and when the requested backend is unavailable.
    """
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
79
+
80
+ Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"]
81
+ BackwardMode = Literal["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
82
+
83
+
84
+ # -----------------------------
85
+ # Reproducibility and device
86
+ # -----------------------------
87
+
88
def set_seed(seed: int) -> None:
    """Seed the stdlib and torch RNGs (and every CUDA device when present)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
93
+
94
+
95
def default_device() -> str:
    """Pick the best available backend, preferring cuda, then mps, then cpu."""
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
101
+
102
+
103
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU-resident torch.Generator seeded with ``seed``."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator
107
+
108
+
109
+ # -----------------------------
110
+ # Data
111
+ # -----------------------------
112
+
113
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Generate a deterministic synthetic text corpus.

    Produces ``n_sentences`` newline-terminated sentences drawn from six
    templates using a private ``random.Random(seed)``, so the output is
    fully reproducible and independent of the global RNG state.
    """
    rng = random.Random(seed)
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]

    sentences: List[str] = []
    for _ in range(n_sentences):
        # NOTE: the template index and the per-template rng calls must stay in
        # this exact order to keep the corpus reproducible per seed.
        template = rng.randrange(6)
        if template == 0:
            sentence = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
        elif template == 1:
            sentence = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
        elif template == 2:
            a, b = rng.sample(symbols, 2)
            sentence = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}."
        elif template == 3:
            sentence = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
        elif template == 4:
            seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
            sentence = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}."
        else:
            sentence = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."
        sentences.append(sentence)
    return "\n".join(sentences) + "\n"
148
+
149
+
150
class CharCorpus:
    """Character-level corpus with a 90/10 train/val split.

    Builds a char<->index vocabulary from ``text`` and serves random
    (input, next-char target) windows of length ``block_size``.
    """

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        boundary = int(0.9 * len(encoded))
        self.train_data = encoded[:boundary]
        self.val_data = encoded[boundary:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` windows; targets are inputs shifted by one."""
        source = self.train_data if split == "train" else self.val_data
        max_start = len(source) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        starts = torch.randint(max_start, (batch_size,), generator=generator)
        x = torch.stack([source[s : s + self.block_size] for s in starts])
        y = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return x.to(self.device), y.to(self.device)
178
+
179
+
180
def load_text(args: argparse.Namespace) -> str:
    """Return the training text: the file at --text_path if set, else a synthetic corpus."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as f:
        return f.read()
185
+
186
+
187
+ # -----------------------------
188
+ # Sparse Linear autograd
189
+ # -----------------------------
190
+
191
class MaskedLinearFunction(torch.autograd.Function):
    """F.linear forward with a row-sparse backward.

    ``active_rows`` is treated as a per-output-row boolean indicator of shape
    [out_features] (its use with ``torch.nonzero`` implies this — confirm at
    call sites). Only active rows receive dW/db; dX can optionally be
    restricted to active rows as well (``sparse_dx``). The backward dispatches
    to the Metal extension when eligible (fp32, MPS, contiguous), otherwise it
    uses a mathematically identical indexed-PyTorch fallback.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_rows: torch.Tensor,
        sparse_dx: bool,
    ) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_rows)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):  # type: ignore[override]
        x, weight, active_rows = ctx.saved_tensors
        sparse_dx = bool(ctx.sparse_dx)
        has_bias = bool(ctx.has_bias)

        x_shape = x.shape
        x_flat = x.reshape(-1, x.shape[-1]).contiguous()
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1]).contiguous()

        active_idx = torch.nonzero(active_rows, as_tuple=False).flatten().to(dtype=torch.long).contiguous()

        can_use_metal = (
            USE_METAL_KERNEL
            and sparse_linear_metal is not None
            and x.device.type == "mps"
            and weight.device.type == "mps"
            and x.dtype == torch.float32
            and weight.dtype == torch.float32
            and gy_flat.dtype == torch.float32
            and x_flat.is_contiguous()
            and gy_flat.is_contiguous()
            and weight.is_contiguous()
        )

        if can_use_metal:
            # BUG FIX: the extension's pybind signatures (sparse_linear_ops.mm) are
            #   sparse_linear_grad_wb(x2d, gy2d, active_mask) -> (grad_w, grad_b)
            #   sparse_linear_grad_x(gy2d, weight, active_mask)
            # taking a bool [Out] mask. The previous code passed an index tensor
            # plus an extra out_features argument, which fails at runtime
            # (wrong arity and dtype check).
            active_mask = active_rows.reshape(-1).to(device=gy_flat.device, dtype=torch.bool).contiguous()
            grad_weight, grad_bias_full = sparse_linear_metal.sparse_linear_grad_wb(x_flat, gy_flat, active_mask)
            grad_bias = grad_bias_full if has_bias else None
            if sparse_dx:
                grad_x_flat = sparse_linear_metal.sparse_linear_grad_x(gy_flat, weight.contiguous(), active_mask)
            else:
                grad_x_flat = gy_flat @ weight
            return grad_x_flat.reshape(x_shape), grad_weight, grad_bias, None, None

        # Portable PyTorch fallback. Mathematically identical, but usually not
        # an actual wall-clock optimization on MPS/CPU because indexing and
        # general matmul overhead dominate.
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if has_bias else None

        if active_idx.numel() > 0:
            gy_active = gy_flat[:, active_idx]
            grad_weight[active_idx] = gy_active.transpose(0, 1) @ x_flat
            if grad_bias is not None:
                grad_bias[active_idx] = gy_active.sum(dim=0)
            grad_x_flat = gy_active @ weight[active_idx] if sparse_dx else gy_flat @ weight
        else:
            # No active rows: dW/db stay zero; dX is zero only in sparse mode.
            grad_x_flat = torch.zeros_like(x_flat) if sparse_dx else gy_flat @ weight

        return grad_x_flat.reshape(x_shape), grad_weight, grad_bias, None, None
267
+
268
+
269
class SparseLinear(nn.Linear):
    """nn.Linear whose backward can be restricted to a subset of output rows."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        # Sparse backward is disarmed until set_sparse_backward() enables it.
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_rows: Optional[torch.Tensor] = None

    def set_sparse_backward(self, enabled: bool, active_rows: Optional[torch.Tensor], sparse_dx: bool) -> None:
        """Arm/disarm the row-sparse backward for subsequent forward calls."""
        self.sparse_enabled = bool(enabled)
        self.sparse_dx = bool(sparse_dx)
        self.active_rows = active_rows

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.sparse_enabled and self.active_rows is not None:
            return MaskedLinearFunction.apply(x, self.weight, self.bias, self.active_rows, self.sparse_dx)
        # Dense path: plain linear, ordinary autograd.
        return F.linear(x, self.weight, self.bias)
287
+
288
+
289
+ # -----------------------------
290
+ # Mini GPT
291
+ # -----------------------------
292
+
293
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on SparseLinear projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)  # fused q,k,v projection
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        def _heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = self.c_attn(x).split(C, dim=2)
        q, k, v = _heads(q), _heads(k), _heads(v)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
318
+
319
+
320
class FeedForward(nn.Module):
    """GELU MLP (n_embd -> 4*n_embd -> n_embd) with dropout on the output."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
329
+
330
+
331
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
343
+
344
+
345
class MiniGPT(nn.Module):
    """Minimal GPT: token + position embeddings, N blocks, SparseLinear LM head."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*(Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)))
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = SparseLinear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when targets is None."""
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.drop(self.tok_emb(idx) + self.pos_emb(positions)[None, :, :])
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
368
+
369
+
370
def named_sparse_linear_modules(model: nn.Module) -> List[Tuple[str, SparseLinear]]:
    """Return every (qualified_name, module) pair for SparseLinear layers in model."""
    found: List[Tuple[str, SparseLinear]] = []
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            found.append((name, module))
    return found
372
+
373
+
374
def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    """Return (total_params, sparse_linear_params, fraction held by SparseLinear)."""
    total = sum(p.numel() for p in model.parameters())
    linear = sum(
        m.weight.numel() + (m.bias.numel() if m.bias is not None else 0)
        for _, m in named_sparse_linear_modules(model)
    )
    # max(1, total) guards the degenerate parameter-free model.
    return total, linear, linear / max(1, total)
382
+
383
+
384
def configure_sparse_linears(
    model: nn.Module,
    masker: Optional["RowMasker"],
    enabled: bool,
    backward_mode: Optional[str],
) -> None:
    """Push the masker's current row masks and backward mode into every SparseLinear."""
    sparse_dx = backward_mode == "sparse_dW_sparse_dX"
    for _, module in named_sparse_linear_modules(model):
        rows = None if masker is None else masker.row_mask_for(module)
        module.set_sparse_backward(enabled=enabled, active_rows=rows, sparse_dx=sparse_dx)
394
+
395
+
396
+ # -----------------------------
397
+ # Mask selector
398
+ # -----------------------------
399
+
400
class RowMasker:
    """Selects the 'active' output rows of every SparseLinear for each step.

    Keeps per-row selector state (EMA gradient mass, observation counts,
    active/previous-active masks) over a flat "block id" space that
    concatenates the output rows of all SparseLinear modules, and turns the
    configured policy into boolean row masks consumed by the sparse backward
    passes and the masked optimizer.

    Fixes over the previous revision: the duplicated ``@torch.no_grad()``
    decorator on ``update_predictor_from_observed_mass`` is removed, and the
    dead ``ids`` accumulator in ``__init__`` (built but never read) is gone.
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        ucb_alpha: float,
        mass_init: float,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.ucb_alpha = ucb_alpha
        self.mass_init = mass_init
        self.device = device
        self.step_index = 0

        # Assign each SparseLinear output row a unique global block id.
        self.linear_modules = [m for _, m in named_sparse_linear_modules(model)]
        self.module_to_ids: Dict[SparseLinear, torch.Tensor] = {}
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            self.module_to_ids[m] = torch.arange(offset, offset + n, device=device)
            offset += n
        self.n_blocks = offset

        # Per-row selector statistics.
        self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.observed_count = torch.zeros(self.n_blocks, device=device)
        self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)

        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[SparseLinear, torch.Tensor] = {
            m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
        }

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask marking the top-``fraction`` entries of ``values`` (at least one)."""
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        # Tiny noise breaks ties so equal values don't always resolve the same way.
        noisy = values + 1e-9 * torch.rand_like(values)
        mask[torch.topk(noisy, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard similarity |a∩b| / |a∪b| of two boolean masks (0.0 when both empty)."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        """Install ``active`` and slice it into per-module row masks."""
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]

    def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
        """Pick top-scoring rows plus a random exploration complement."""
        n = self.n_blocks
        k_total = max(1, int(self.active_fraction * n))
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(n, dtype=torch.bool, device=self.device)

        if k_exploit > 0:
            active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
        if k_explore > 0:
            # Explore uniformly among the not-yet-selected rows.
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
            active[pick] = True
        return active

    def choose_pre_backward(self, step: int) -> None:
        """Choose the active set for ``step`` before the backward pass runs."""
        self.step_index = step
        if step < self.warmup_steps:
            # Dense bootstrap: every row active.
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "oracle_current":
            # Oracle cannot choose until the dense audit gradient is known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "random":
            self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
            return

        if self.policy == "stale_current":
            self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
            return

        if self.policy == "predicted_magnitude":
            self._set_active(self._sample_exploit_explore(self.predicted_mass))
            return

        if self.policy == "ucb_magnitude":
            # Exploit EMA mass plus a count-based exploration bonus (UCB style).
            t = max(1, step - self.warmup_steps + 1)
            log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
            bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
            bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
            self._set_active(self._sample_exploit_explore(self.predicted_mass + bonus))
            return

        raise ValueError(f"Unknown policy: {self.policy}")

    @torch.no_grad()
    def current_gradient_mass_from_grads(self) -> torch.Tensor:
        """Per-row L2 gradient mass read from the currently stored .grad tensors."""
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def update_predictor_from_observed_mass(self, mass: torch.Tensor, observed: Optional[torch.Tensor] = None) -> Dict[str, float]:
        """Update EMA statistics only for observed rows.

        After warmup, sparse backward only gives trustworthy gradients for active
        rows, so only those rows are allowed to update predicted_mass.
        """
        if observed is None:
            observed = self.active

        new_active = observed & (self.observed_count == 0)
        self.predicted_mass.mul_(self.unobserved_decay)

        if bool(observed.any().item()):
            obs_mass = mass[observed]
            # Rows seen for the first time take the observation directly; the
            # rest blend it into the EMA.
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()

        stability = self._jaccard(self.active, self.prev_active)
        self.prev_active = self.active.clone()

        return {
            "stability": stability,
            "active_fraction_real": float(self.active.float().mean().item()),
            "coverage": float((self.observed_count > 0).float().mean().item()),
            "avg_obs_count": float(self.observed_count.mean().item()),
            "new_active_fraction": float(new_active.float().mean().item()),
        }

    @torch.no_grad()
    def audit_metrics_from_mass(self, mass: torch.Tensor) -> Dict[str, float]:
        """Compute dense-audit metrics without updating the practical selector."""
        active = self.active
        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)

        # Concentration: fraction of total mass held by the top-20% rows.
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        return {
            "cosine": cosine,
            "norm_ratio": cosine,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
        }

    def audit_and_update_from_mass(self, step: int, mass: torch.Tensor) -> Dict[str, float]:
        """Audit against a dense gradient mass and update the practical selector."""
        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        new_active = active & (self.observed_count == 0)

        # Practical rule: update predicted statistics only for active/observed rows.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()

        # Dense audit signal; only stale_current is allowed to use this for selection.
        self.last_full_mass = mass.detach().clone()

        return {
            "cosine": cosine,
            "norm_ratio": cosine,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
            "coverage": float((self.observed_count > 0).float().mean().item()),
            "avg_obs_count": float(self.observed_count.mean().item()),
            "new_active_fraction": float(new_active.float().mean().item()),
        }

    def row_mask_for(self, module: SparseLinear) -> Optional[torch.Tensor]:
        """Boolean row mask for ``module``, or None if it is unknown to this masker."""
        return self.row_masks.get(module)
636
+
637
+
638
+ # -----------------------------
639
+ # Masked Adam
640
+ # -----------------------------
641
+
642
class MaskedAdam:
    """Hand-rolled Adam-style optimizer with optional row masking.

    For parameters that belong to a SparseLinear module, the update (and the
    moment buffers) can be restricted to the masker's currently-active rows.
    Note: the update is m / (sqrt(v) + eps) with no bias-correction terms.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        # Lazily-created per-parameter first/second moment buffers.
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Map each SparseLinear weight/bias parameter back to (module, kind)
        # so step() can look up its row mask.
        self.linear_param: Dict[nn.Parameter, Tuple[SparseLinear, str]] = {}
        for _, m in named_sparse_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        """Drop all gradients (sets .grad to None rather than zeroing in place)."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one optimizer update, row-masked where a masker provides a mask."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            # Optionally freeze everything that is not a SparseLinear parameter.
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                continue

            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                # L2-style decay folded into the gradient (not decoupled AdamW).
                g = g.add(p, alpha=self.weight_decay)

            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    # Weight matrices broadcast the row mask over trailing dims.
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                # Dense path: in-place moment updates, no bias correction.
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                # MPS can mis-handle expanded boolean masks for row-wise assignment
                # (e.g. reporting nonsense out-of-bounds indices). Use explicit
                # row indices and index_copy_ instead. This also avoids materializing
                # a full expanded mask for weight matrices.
                active_rows = row_mask.reshape(-1).nonzero(as_tuple=False).flatten()
                if active_rows.numel() == 0:
                    continue

                m_rows = m.index_select(0, active_rows)
                v_rows = v.index_select(0, active_rows)
                g_rows = g.index_select(0, active_rows)

                new_m_rows = self.beta1 * m_rows + (1.0 - self.beta1) * g_rows
                new_v_rows = self.beta2 * v_rows + (1.0 - self.beta2) * g_rows * g_rows
                update_rows = new_m_rows / (torch.sqrt(new_v_rows) + self.eps)
                p_rows = p.index_select(0, active_rows) - self.lr * update_rows

                # Write the updated rows back; inactive rows keep stale moments.
                m.index_copy_(0, active_rows, new_m_rows)
                v.index_copy_(0, active_rows, new_v_rows)
                p.index_copy_(0, active_rows, p_rows)
719
+
720
+
721
+ # -----------------------------
722
+ # Training utilities
723
+ # -----------------------------
724
+
725
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int, seed: int) -> Dict[str, float]:
    """Mean train/val loss over ``eval_iters`` batches with sparse backward disabled."""
    model.eval()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    results: Dict[str, float] = {}
    # Fixed per-split seed offsets keep evaluation batches reproducible.
    for split_name, seed_shift in (("train", 0), ("val", 100000)):
        gen = make_cpu_generator(seed + seed_shift)
        total = 0.0
        for _ in range(eval_iters):
            xb, yb = corpus.get_batch(split_name, batch_size, generator=gen)
            _, loss = model(xb, yb)
            total += float(loss.item())
        results[split_name] = total / eval_iters
    model.train()
    return results
740
+
741
+
742
def dense_audit_pass(model: nn.Module, corpus_batch: Tuple[torch.Tensor, torch.Tensor], opt: MaskedAdam, masker: RowMasker) -> torch.Tensor:
    """Run one fully dense backward on the batch and return per-row gradient mass.

    Gradients are cleared both before and after, so this audit never leaks
    into the optimizer step.
    """
    inputs, targets = corpus_batch
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    opt.zero_grad()
    _, audit_loss = model(inputs, targets)
    audit_loss.backward()
    mass = masker.current_gradient_mass_from_grads()
    opt.zero_grad()
    return mass
751
+
752
+
753
def sparse_training_backward(
    model: nn.Module,
    corpus_batch: Tuple[torch.Tensor, torch.Tensor],
    opt: MaskedAdam,
    masker: Optional[RowMasker],
    backward_mode: Optional[BackwardMode],
) -> float:
    """Run one forward/backward on the batch and return the scalar loss.

    "masked_optimizer" (or no masker/mode) uses a fully dense backward; the
    other modes enable the custom sparse backward for the duration of the
    pass. Sparse mode is always switched off again before returning.
    """
    inputs, targets = corpus_batch
    opt.zero_grad()

    dense = masker is None or backward_mode is None or backward_mode == "masked_optimizer"
    if dense:
        configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    else:
        configure_sparse_linears(model, masker=masker, enabled=True, backward_mode=backward_mode)

    _, loss = model(inputs, targets)
    loss.backward()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    return float(loss.item())
772
+
773
+
774
def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    backward_mode: Optional[BackwardMode],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    """Train one model configuration and return a metrics row for the summary.

    policy=None runs the dense baseline; otherwise a RowMasker drives the
    chosen sparse backward_mode. Returns losses, wall-clock stats, and the
    averaged selector-quality metrics accumulated during training.
    """
    # Same model initialization and same minibatch sequence for every run by default.
    set_seed(args.seed + (seed_offset if args.unpaired_seeds else 0))
    data_gen = make_cpu_generator(args.seed + 12345)

    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )

    # Running sums/counts for the selector-quality metrics; averaged at the end.
    sums = {
        "cosine": 0.0,
        "norm_ratio": 0.0,
        "top20_mass": 0.0,
        "jacc_oracle": 0.0,
        "stability": 0.0,
        "active_fraction_real": 0.0,
        "coverage": 0.0,
        "avg_obs_count": 0.0,
        "new_active_fraction": 0.0,
    }
    counts = {k: 0 for k in sums}

    def add_metrics(metrics: Dict[str, float]) -> None:
        # Accumulate only metrics we track; counts let partial reporting average correctly.
        for k, v in metrics.items():
            if k in sums:
                sums[k] += float(v)
                counts[k] += 1

    if args.benchmark_sync:
        sync_device(dev)
    t0 = time.perf_counter()

    for step in range(args.steps):
        batch = corpus.get_batch("train", args.batch_size, generator=data_gen)

        if masker is None:
            # Dense baseline: plain backward + full optimizer step.
            loss_value = sparse_training_backward(model, batch, opt, masker=None, backward_mode=None)
            opt.step()
        else:
            if step < warmup_steps:
                # Dense bootstrap. Every row is active and every row updates the predictor.
                masker._set_active(torch.ones(masker.n_blocks, dtype=torch.bool, device=dev))
                loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode="masked_optimizer")
                full_mass = masker.current_gradient_mass_from_grads()
                masker.last_full_mass = full_mass.detach().clone()
                add_metrics(masker.audit_metrics_from_mass(full_mass))
                add_metrics(masker.update_predictor_from_observed_mass(full_mass, observed=masker.active))
                opt.step()
            else:
                masker.choose_pre_backward(step)

                if policy == "oracle_current":
                    # Explicit upper bound. Oracle necessarily computes dense gradients to choose rows.
                    full_mass = dense_audit_pass(model, batch, opt, masker)
                    masker._set_active(masker._topk_mask(full_mass, active_fraction))
                    masker.last_full_mass = full_mass.detach().clone()
                    add_metrics(masker.audit_metrics_from_mass(full_mass))
                elif args.audit_every > 0 and ((step - warmup_steps) % args.audit_every == 0):
                    # Measurement only. Do not update predicted_magnitude/ucb/random with this dense mass.
                    full_mass = dense_audit_pass(model, batch, opt, masker)
                    add_metrics(masker.audit_metrics_from_mass(full_mass))
                    if policy == "stale_current":
                        masker.last_full_mass = full_mass.detach().clone()

                loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode=backward_mode)

                # Practical selector update: only active rows were observed by the training backward pass.
                observed_mass = masker.current_gradient_mass_from_grads()
                add_metrics(masker.update_predictor_from_observed_mass(observed_mass, observed=masker.active))
                opt.step()

        if args.benchmark_sync:
            sync_device(dev)

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 555)
            name = "dense" if policy is None else f"{policy}/{backward_mode}"
            print(
                f"{name:38s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"loss={loss_value:.4f} train={losses['train']:.4f} val={losses['val']:.4f}"
            )

    if args.benchmark_sync:
        sync_device(dev)
    elapsed_s = time.perf_counter() - t0
    train_tokens = float(args.steps * args.batch_size * args.block_size)
    step_ms = 1000.0 * elapsed_s / max(1, args.steps)
    tokens_per_s = train_tokens / max(elapsed_s, 1e-9)

    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 999)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "mode": "dense" if backward_mode is None else backward_mode,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
        "elapsed_s": elapsed_s,
        "step_ms": step_ms,
        "tokens_per_s": tokens_per_s,
    }
    if masker is None:
        # Selector metrics are undefined for the dense baseline.
        row.update({
            "cosine": float("nan"),
            "norm_ratio": float("nan"),
            "top20_mass": float("nan"),
            "jacc_oracle": float("nan"),
            "stability": float("nan"),
            "active_fraction_real": 1.0,
            "coverage": float("nan"),
            "avg_obs_count": float("nan"),
            "new_active_fraction": float("nan"),
        })
    else:
        for k in sums:
            row[k] = (sums[k] / counts[k]) if counts[k] > 0 else float("nan")
    return row
922
+
923
def print_summary(rows: List[Dict[str, float | str]]) -> None:
    """Print the fixed-width results table for all completed runs."""
    print("\nSummary")
    header = (
        f"{'run':>22s} {'mode':>19s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
        f"{'val':>8s} {'train':>8s} {'ms':>8s} {'tok/s':>9s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
        f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
    )
    print(header)
    print("-" * len(header))
    for r in rows:
        # Column widths mirror the header above; missing timing keys print as NaN.
        print(
            f"{str(r['run']):>22s} "
            f"{str(r['mode']):>19s} "
            f"{float(r['target_active']):7.3f} "
            f"{float(r['active_fraction_real']):7.3f} "
            f"{int(float(r['warmup'])):5d} "
            f"{float(r['explore']):5.2f} "
            f"{float(r['val_loss']):8.4f} "
            f"{float(r['train_loss']):8.4f} "
            f"{float(r.get('step_ms', float('nan'))):8.2f} "
            f"{float(r.get('tokens_per_s', float('nan'))):9.0f} "
            f"{float(r['cosine']):7.3f} "
            f"{float(r['top20_mass']):7.3f} "
            f"{float(r['jacc_oracle']):7.3f} "
            f"{float(r['stability']):7.3f} "
            f"{float(r['coverage']):7.3f} "
            f"{float(r['new_active_fraction']):7.3f}"
        )
951
+
952
+
953
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI for the sparse-backward benchmark harness."""
    p = argparse.ArgumentParser()
    p.add_argument("--text_path", type=str, default=None)
    p.add_argument("--synthetic_sentences", type=int, default=12000)
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--quick", action="store_true")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--block_size", type=int, default=64)
    p.add_argument("--n_layer", type=int, default=2)
    p.add_argument("--n_head", type=int, default=4)
    p.add_argument("--n_embd", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--active_fractions", type=float, nargs="+", default=[0.05, 0.02])
    p.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "random"])
    p.add_argument(
        "--backward_modes",
        type=str,
        nargs="+",
        default=["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"],
    )
    p.add_argument("--explore_fractions", type=float, nargs="+", default=[0.0])
    p.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5])
    p.add_argument("--mass_beta", type=float, default=0.95)
    p.add_argument("--unobserved_decay", type=float, default=1.0)
    p.add_argument("--mass_init", type=float, default=0.0)
    p.add_argument("--ucb_alpha", type=float, default=1.0)
    p.add_argument("--freeze_non_linear_when_sparse", action="store_true")
    p.add_argument("--eval_interval", type=int, default=200)
    p.add_argument("--eval_iters", type=int, default=20)
    p.add_argument("--seed", type=int, default=7)
    p.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda", "mps"])
    p.add_argument("--audit_every", type=int, default=0, help="Dense audit interval after warmup. 0 disables audits except oracle_current.")
    p.add_argument("--kernel_backend", type=str, default="torch", choices=["torch", "metal"], help="Use PyTorch fallback or the optional Metal active-row Linear backward kernel.")
    p.add_argument("--benchmark_sync", action="store_true", help="Synchronize after every train step for fair wall-clock benchmarking on async devices.")
    p.add_argument("--unpaired_seeds", action="store_true", help="Use different init seeds per run instead of paired seeds.")
    p.add_argument("--verbose", action="store_true")
    return p.parse_args()
992
+
993
+
994
def main() -> None:
    """Entry point: validate CLI choices, run the dense baseline, then every
    (backward_mode, active_fraction, policy, warmup, explore) combination, and
    print the summary table."""
    global USE_METAL_KERNEL
    args = parse_args()
    USE_METAL_KERNEL = args.kernel_backend == "metal"
    if USE_METAL_KERNEL and sparse_linear_metal is None:
        raise RuntimeError(
            "--kernel_backend metal requested, but sparse_linear_metal is not importable. "
            "Build in the repo venv (PyTorch must be importable during setup): "
            "cd sparse_linear_v10_metal && python3 -m pip install -e . --no-build-isolation"
        )
    if args.quick:
        # Tiny smoke-test configuration that overrides most CLI settings.
        args.steps = 40
        args.eval_iters = 2
        args.batch_size = 8
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 1200
        args.active_fractions = [0.05]
        args.policies = ["predicted_magnitude", "random"]
        args.backward_modes = ["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
        args.explore_fractions = [0.0]
        args.warmup_steps_list = [5]
        args.audit_every = 10

    # Fail fast on typos before spending time training.
    valid_policies = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
    valid_modes = {"masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"}
    for pol in args.policies:
        if pol not in valid_policies:
            raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid_policies)}")
    for mode in args.backward_modes:
        if mode not in valid_modes:
            raise ValueError(f"Unknown backward mode {mode!r}. Valid modes: {sorted(valid_modes)}")

    set_seed(args.seed)
    dev = args.device if args.device != "auto" else default_device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"backward_modes={args.backward_modes}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")
    print(f"paired_seeds={not args.unpaired_seeds}")
    print(f"kernel_backend={args.kernel_backend} benchmark_sync={args.benchmark_sync} metal_available={sparse_linear_metal is not None}")
    print(f"audit_every={args.audit_every} (0 means no dense audit after warmup, except oracle_current)")

    # Throwaway model only used to report parameter counts.
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")

    if args.dropout != 0.0:
        print("warning: dropout is nonzero; dense audit and sparse training passes may see different dropout masks")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(
        train_run(
            corpus,
            args,
            policy=None,
            backward_mode=None,
            active_fraction=1.0,
            warmup_steps=0,
            explore_fraction=0.0,
            seed_offset=0,
        )
    )

    # Sweep the full configuration grid; each run gets a distinct seed offset
    # (only used when --unpaired_seeds is set).
    seed_offset = 100
    for mode in args.backward_modes:
        for af in args.active_fractions:
            for pol in args.policies:
                explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
                for warmup in args.warmup_steps_list:
                    for explore in explore_values:
                        print(
                            f"\nRunning mode={mode}, policy={pol}, "
                            f"active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}"
                        )
                        rows.append(
                            train_run(
                                corpus,
                                args,
                                policy=pol,  # type: ignore[arg-type]
                                backward_mode=mode,  # type: ignore[arg-type]
                                active_fraction=af,
                                warmup_steps=warmup,
                                explore_fraction=explore,
                                seed_offset=seed_offset,
                            )
                        )
                        seed_offset += 1

    print_summary(rows)

    print("\nNotes")
    print("  masked_optimizer is the v7-style dense-backward simulation control.")
    print("  sparse_dW_full_dX uses custom Linear backward: sparse weight/bias grads, full input gradient.")
    print("  sparse_dW_sparse_dX uses custom Linear backward: sparse weight/bias grads and sparse input gradient.")
    print("  oracle_current uses dense audit gradients to choose rows; it is an upper bound.")
    print("  predicted_magnitude uses EMA mass from active/observed rows only.")
    print("  random is the sparse-support control.")
    print("  v10 does not compute dense audit gradients after warmup unless --audit_every > 0, except oracle_current.")
    print("  predicted_magnitude updates EMA statistics only from active rows observed by the training backward pass.")
    print("  kernel_backend=torch uses the portable PyTorch fallback; kernel_backend=metal uses sparse_linear_metal if installed.")
    print("  Use --benchmark_sync on MPS/CUDA for honest step_ms/tokens_per_s comparisons.")
    print("  cosine/top20/jacc require --audit_every > 0, otherwise there is no dense reference gradient.")


if __name__ == "__main__":
    main()
experiments/sparse_linear_v11_gather_vs_metal/README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Sparse Transformer v10: Metal-backed active-row Linear backward benchmark
2
+
3
+ This bundle is the first empirical optimization test rather than only a correctness prototype.
4
+ It compares dense Transformer training against sparse active-row backward variants and reports
5
+ validation loss plus wall-clock metrics (`step_ms`, `tokens_per_s`).
6
+
7
+ ## Files
8
+
9
+ - `sparse_transformer_v10.py` — training + benchmark harness.
10
+ - `sparse_linear.metal` — Metal kernels for active-row `dW/db` and sparse `dX`.
11
+ - `sparse_linear_ops.mm` — PyTorch/MPS extension glue.
12
+ - `setup.py` — builds the extension and compiles the `.metallib`.
13
+
14
+ ## Install the Metal extension
15
+
16
+ From this directory on macOS with PyTorch MPS available:
17
+
18
+ ```bash
19
+ python3 -m pip install -e .
20
+ ```
21
+
22
+ This uses `xcrun metal` and `xcrun metallib`, so Xcode command-line tools are required.
23
+
24
+ ## Correctness/sanity run with PyTorch fallback
25
+
26
+ ```bash
27
+ python3 sparse_transformer_v10.py \
28
+ --device mps \
29
+ --steps 2000 \
30
+ --active_fractions 0.05 0.02 \
31
+ --warmup_steps_list 5 \
32
+ --policies predicted_magnitude random \
33
+ --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
34
+ --audit_every 0 \
35
+ --kernel_backend torch \
36
+ --benchmark_sync
37
+ ```
38
+
39
+ ## Metal benchmark run
40
+
41
+ ```bash
42
+ python3 sparse_transformer_v10.py \
43
+ --device mps \
44
+ --steps 2000 \
45
+ --active_fractions 0.05 0.02 \
46
+ --warmup_steps_list 5 \
47
+ --policies predicted_magnitude random \
48
+ --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
49
+ --audit_every 0 \
50
+ --kernel_backend metal \
51
+ --benchmark_sync
52
+ ```
53
+
54
+ ## Empirical pass/fail criteria
55
+
56
+ A genuine optimization would show:
57
+
58
+ 1. `predicted_magnitude` validation loss close to the PyTorch fallback v9/v10 result.
59
+ 2. `predicted_magnitude` much better than `random` at the same active fraction.
60
+ 3. `--kernel_backend metal` has lower `step_ms` or higher `tokens_per_s` than `--kernel_backend torch` and ideally dense baseline.
61
+
62
+ The first Metal kernels are intentionally simple and fp32-only. They are meant to prove or disprove the acceleration path before investing in tiled half-precision kernels.
experiments/sparse_linear_v11_gather_vs_metal/input.txt ADDED
The diff for this file is too large to render. See raw diff
 
experiments/sparse_linear_v11_gather_vs_metal/setup.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Build script for the `sparse_linear` PyTorch extension.
# Compiles the Objective-C++ Metal glue (sparse_linear_ops.mm) and links the
# Metal/Foundation frameworks, so it only builds on macOS with Xcode tools.
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

setup(
    name='sparse_linear',
    ext_modules=[
        CppExtension(
            name='sparse_linear',
            sources=['sparse_linear_ops.mm'],
            # Force Objective-C++ mode so the #import / Metal syntax compiles.
            extra_compile_args=['-ObjC++'],
            extra_link_args=['-framework', 'Metal', '-framework', 'Foundation']
        )
    ],
    # BuildExtension supplies the torch include/library paths automatically.
    cmdclass={'build_ext': BuildExtension}
)
experiments/sparse_linear_v11_gather_vs_metal/sparse_linear.metal ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_stdlib>
2
+ using namespace metal;
3
+
4
// Launch parameters shared with the host-side extension; the field order and
// widths must match the C++ SparseLinearParams struct in sparse_linear_ops.mm.
struct SparseLinearParams {
    uint32_t N;   // flattened batch/token rows
    uint32_t In;  // in_features
    uint32_t Out; // out_features
    uint32_t K;   // active rows count
};
10
+
11
// Computes dW and db restricted to the active output rows.
// One thread handles one (active row k, input column c) cell of dW; the
// thread with c == 0 additionally reduces the bias gradient for its row.
// NOTE(review): assumes `active` contains unique row indices — duplicates
// would make two threads write the same grad_w/grad_b rows. Confirm the
// caller guarantees uniqueness.
kernel void sparse_linear_grad_w_float(
    device const float* x [[buffer(0)]], // [N, In]
    device const float* gy [[buffer(1)]], // [N, Out]
    device const int64_t* active [[buffer(2)]], // [K]
    device float* grad_w [[buffer(3)]], // [Out, In], zeroed by caller
    device float* grad_b [[buffer(4)]], // [Out], zeroed by caller
    constant SparseLinearParams& p [[buffer(5)]],
    uint2 tid [[thread_position_in_grid]]) {
    uint k = tid.y;
    uint c = tid.x;
    // Guard the padded remainder of the grid (dispatch rounds up to 16x16).
    if (k >= p.K || c >= p.In) return;

    int64_t row64 = active[k];
    // Skip out-of-range indices rather than corrupt memory.
    if (row64 < 0 || row64 >= (int64_t)p.Out) return;
    uint row = (uint)row64;

    // dW[row, c] = sum_n gy[n, row] * x[n, c]
    float acc = 0.0f;
    for (uint n = 0; n < p.N; ++n) {
        acc += gy[n * p.Out + row] * x[n * p.In + c];
    }
    grad_w[row * p.In + c] = acc;

    // One thread per active row computes bias: db[row] = sum_n gy[n, row].
    if (c == 0) {
        float bacc = 0.0f;
        for (uint n = 0; n < p.N; ++n) {
            bacc += gy[n * p.Out + row];
        }
        grad_b[row] = bacc;
    }
}
42
+
43
// Computes dX = gy[:, active] @ W[active, :] without materializing gathered
// slices: each thread owns one (row n, column c) output cell and loops over
// the K active rows.
kernel void sparse_linear_grad_x_float(
    device const float* gy [[buffer(0)]], // [N, Out]
    device const float* weight [[buffer(1)]], // [Out, In]
    device const int64_t* active [[buffer(2)]], // [K]
    device float* grad_x [[buffer(3)]], // [N, In]
    constant SparseLinearParams& p [[buffer(4)]],
    uint2 tid [[thread_position_in_grid]]) {
    uint c = tid.x;
    uint n = tid.y;
    // Guard the padded remainder of the dispatched grid.
    if (n >= p.N || c >= p.In) return;

    float acc = 0.0f;
    for (uint k = 0; k < p.K; ++k) {
        int64_t row64 = active[k];
        // Skip invalid indices instead of reading out of bounds.
        if (row64 < 0 || row64 >= (int64_t)p.Out) continue;
        uint row = (uint)row64;
        acc += gy[n * p.Out + row] * weight[row * p.In + c];
    }
    grad_x[n * p.In + c] = acc;
}
experiments/sparse_linear_v11_gather_vs_metal/sparse_linear_ops.mm ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/mps/MPSStream.h>
4
+ #include <ATen/mps/MPSDevice.h>
5
+ #include <c10/util/Exception.h>
6
+ #include <filesystem>
7
+ #include <dlfcn.h>
8
+ #import <Metal/Metal.h>
9
+ #import <Foundation/Foundation.h>
10
+ #include <mutex>
11
+
12
+ namespace fs = std::filesystem;
13
+
14
namespace {
// Host-side mirror of the Metal SparseLinearParams struct; the field order
// and widths must match sparse_linear.metal exactly.
struct SparseLinearParams {
    uint32_t N;
    uint32_t In;
    uint32_t Out;
    uint32_t K;
};

// Process-wide, lazily-initialized Metal objects, guarded by g_mutex.
static id<MTLLibrary> g_lib = nil;
static id<MTLComputePipelineState> g_pipeline_grad_w = nil;
static id<MTLComputePipelineState> g_pipeline_grad_x = nil;
static std::mutex g_mutex;

// Resolve the path of the compiled metallib, assumed to be installed next to
// this extension's shared object (dladdr on our own symbol finds that dir).
static std::string metallib_path_for_this_module() {
    Dl_info info;
    if (dladdr((void*)&metallib_path_for_this_module, &info) == 0 || info.dli_fname == nullptr) return std::string();
    fs::path so_path(info.dli_fname);
    return (so_path.parent_path() / "sparse_linear_ops.metallib").string();
}

// Load the metallib exactly once. Caller must already hold g_mutex.
static void ensure_library_locked(id<MTLDevice> device) {
    if (g_lib != nil) return;
    std::string path = metallib_path_for_this_module();
    TORCH_CHECK(!path.empty(), "sparse_linear_ops: failed to locate extension path via dladdr");
    NSString* ns_path = [NSString stringWithUTF8String:path.c_str()];
    NSURL* url = [NSURL fileURLWithPath:ns_path];
    NSError* err = nil;
    g_lib = [device newLibraryWithURL:url error:&err];
    if (g_lib == nil) {
        const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
        TORCH_CHECK(false, "sparse_linear_ops: failed to load metallib at ", path, ": ", msg);
    }
}

// Thread-safe get-or-create of the compute pipeline for `fn_name`.
static id<MTLComputePipelineState> ensure_pipeline(id<MTLDevice> device, id<MTLComputePipelineState>* pipeline, const char* fn_name) {
    std::lock_guard<std::mutex> lock(g_mutex);
    ensure_library_locked(device);
    if (*pipeline != nil) return *pipeline;
    NSString* ns_fn = [NSString stringWithUTF8String:fn_name];
    id<MTLFunction> fn = [g_lib newFunctionWithName:ns_fn];
    TORCH_CHECK(fn != nil, "sparse_linear_ops: function `", fn_name, "` not found in metallib");
    NSError* err = nil;
    *pipeline = [device newComputePipelineStateWithFunction:fn error:&err];
    if (*pipeline == nil) {
        const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
        TORCH_CHECK(false, "sparse_linear_ops: failed to create pipeline for ", fn_name, ": ", msg);
    }
    return *pipeline;
}

// Reinterpret an MPS tensor's storage context as its backing MTLBuffer.
static inline id<MTLBuffer> storage_as_mtlbuffer(const at::Tensor& t) {
    void* ctx = t.storage().data_ptr().get_context();
    TORCH_CHECK(ctx != nullptr, "sparse_linear_ops: expected MPS tensor storage with MTLBuffer context");
    return (__bridge id<MTLBuffer>)ctx;
}

// Byte offset of the tensor's first element within its storage buffer.
static inline NSUInteger storage_offset_bytes(const at::Tensor& t) {
    return (NSUInteger)(t.storage_offset() * (int64_t)t.element_size());
}

// Shared argument validation: MPS device, fp32 dtype, contiguous layout.
static void check_mps_float_contig(const at::Tensor& t, const char* name) {
    TORCH_CHECK(t.device().is_mps(), name, " must be on MPS");
    TORCH_CHECK(t.dtype() == at::kFloat, name, " must be float32 for v10 kernel");
    TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
}
} // namespace
80
+
81
// Compute dW/db for the active output rows only (fp32, MPS backend).
// Returns {grad_w [Out, In], grad_b [Out]}: dense tensors that stay zero
// outside the active rows (and for any out-of-range entries in active_idx,
// which the kernel skips).
std::vector<at::Tensor> sparse_linear_grad_wb(at::Tensor x2d, at::Tensor gy2d, at::Tensor active_idx, int64_t out_features) {
    check_mps_float_contig(x2d, "x2d");
    check_mps_float_contig(gy2d, "gy2d");
    TORCH_CHECK(active_idx.device().is_mps(), "active_idx must be on MPS");
    TORCH_CHECK(active_idx.dtype() == at::kLong, "active_idx must be int64");
    TORCH_CHECK(active_idx.is_contiguous(), "active_idx must be contiguous");
    TORCH_CHECK(x2d.dim() == 2 && gy2d.dim() == 2 && active_idx.dim() == 1, "bad dims");
    TORCH_CHECK(x2d.size(0) == gy2d.size(0), "x2d and gy2d N mismatch");
    TORCH_CHECK(gy2d.size(1) == out_features, "gy2d width must equal out_features");

    int64_t N = x2d.size(0);
    int64_t In = x2d.size(1);
    int64_t Out = out_features;
    int64_t K = active_idx.numel();
    // The kernel only writes active rows, so outputs must start zeroed.
    auto grad_w = at::zeros({Out, In}, x2d.options());
    auto grad_b = at::zeros({Out}, x2d.options());
    if (K == 0) return {grad_w, grad_b};

    id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
    id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_w, "sparse_linear_grad_w_float");
    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
    TORCH_CHECK(stream != nullptr, "failed to get current MPS stream");
    // Flush any coalesced ATen kernels so our encoder sees current buffers.
    stream->endKernelCoalescing();
    id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
    TORCH_CHECK(encoder != nil, "failed to get MTLComputeCommandEncoder");
    [encoder setComputePipelineState:pipeline];

    auto set_tensor = [&](const at::Tensor& t, int idx) {
        [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
    };
    set_tensor(x2d, 0);
    set_tensor(gy2d, 1);
    set_tensor(active_idx, 2);
    set_tensor(grad_w, 3);
    set_tensor(grad_b, 4);
    SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, (uint32_t)K};
    [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:5];

    // Grid is expressed in threadgroups: ceil(In/16) x ceil(K/16) groups of
    // 16x16 threads; the kernel bounds-checks the padded remainder.
    MTLSize tg = MTLSizeMake(16, 16, 1);
    MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((K + 15) / 16), 1);
    [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
    stream->endKernelCoalescing();
    return {grad_w, grad_b};
}
125
+
126
// Compute the input gradient restricted to the active output rows:
// grad_x = gy[:, active] @ weight[active, :], done on-device without
// materializing the gathered slices. Returns a dense [N, In] tensor.
at::Tensor sparse_linear_grad_x(at::Tensor gy2d, at::Tensor weight, at::Tensor active_idx) {
    check_mps_float_contig(gy2d, "gy2d");
    check_mps_float_contig(weight, "weight");
    TORCH_CHECK(active_idx.device().is_mps(), "active_idx must be on MPS");
    TORCH_CHECK(active_idx.dtype() == at::kLong, "active_idx must be int64");
    TORCH_CHECK(active_idx.is_contiguous(), "active_idx must be contiguous");
    TORCH_CHECK(gy2d.dim() == 2 && weight.dim() == 2 && active_idx.dim() == 1, "bad dims");
    int64_t N = gy2d.size(0);
    int64_t Out = gy2d.size(1);
    int64_t In = weight.size(1);
    int64_t K = active_idx.numel();
    TORCH_CHECK(weight.size(0) == Out, "weight out_features must match gy2d width");
    auto grad_x = at::zeros({N, In}, gy2d.options());
    if (K == 0) return grad_x;

    id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
    id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_x, "sparse_linear_grad_x_float");
    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
    TORCH_CHECK(stream != nullptr, "failed to get current MPS stream");
    // Flush any coalesced ATen kernels before encoding our own.
    stream->endKernelCoalescing();
    id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
    TORCH_CHECK(encoder != nil, "failed to get MTLComputeCommandEncoder");
    [encoder setComputePipelineState:pipeline];
    auto set_tensor = [&](const at::Tensor& t, int idx) {
        [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
    };
    set_tensor(gy2d, 0);
    set_tensor(weight, 1);
    set_tensor(active_idx, 2);
    set_tensor(grad_x, 3);
    SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, (uint32_t)K};
    [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:4];
    // ceil(In/16) x ceil(N/16) threadgroups of 16x16 threads.
    MTLSize tg = MTLSizeMake(16, 16, 1);
    MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((N + 15) / 16), 1);
    [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
    stream->endKernelCoalescing();
    return grad_x;
}
164
+
165
// Python bindings: importable as `sparse_linear` with the two ops below.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("sparse_linear_grad_wb", &sparse_linear_grad_wb, "Sparse active-row Linear dW/db (Metal/MPS, fp32)");
    m.def("sparse_linear_grad_x", &sparse_linear_grad_x, "Sparse active-row Linear dX (Metal/MPS, fp32)");
}
experiments/sparse_linear_v11_gather_vs_metal/sparse_transformer_v11.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v12: Hardware-Sympathetic Chunked Sparsity.
3
+
4
+ This version groups rows into hardware-friendly "Chunks" (e.g., 64 rows per chunk).
5
+ By selecting entire chunks, PyTorch can use zero-copy Strided Views. This completely
6
+ bypasses the slow index_select/gather memory copying and feeds data directly into
7
+ the AMX / Tensor Cores at native dense speeds.
8
+
9
+ Run:
10
+ python3 sparse_transformer_v11.py --device mps --benchmark_sync
11
+ """
12
+
13
+ import argparse
14
+ import math
15
+ import random
16
+ import time
17
+ from typing import Dict, List, Literal, Optional, Tuple
18
+
19
+ import torch
20
+
21
+ torch.set_num_threads(1)
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
def sync_device(device: str) -> None:
    """Drain pending GPU work for `device`; CPU/unknown devices are a no-op."""
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps":
        if hasattr(torch, "mps"):
            torch.mps.synchronize()
30
+
31
+ Policy = Literal["predicted_magnitude", "oracle_current", "random"]
32
+ BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
33
+
34
def set_seed(seed: int) -> None:
    """Seed both the stdlib and torch RNGs for reproducible runs."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
37
+
38
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU torch.Generator pre-seeded with `seed`."""
    # Generator.manual_seed returns the generator itself, so this chains.
    return torch.Generator(device="cpu").manual_seed(seed)
42
+
43
+ # -----------------------------
44
+ # Data
45
+ # -----------------------------
46
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic toy corpus of newline-separated sentences.

    Each sentence is 4-10 random words from a tiny fixed vocabulary,
    terminated with a period. Same seed -> identical corpus.
    """
    rng = random.Random(seed)
    vocab = ["ada", "turing", "grace", "lovelace", "gradients", "tokens", "circuits", "features", "boldly", "strangely"]
    sentences = []
    for _ in range(n_sentences):
        # randint must be drawn before choices to preserve the RNG call order.
        length = rng.randint(4, 10)
        sentences.append(" ".join(rng.choices(vocab, k=length)) + ".")
    return "\n".join(sentences)
50
+
51
class CharCorpus:
    """Character-level dataset with a 90/10 train/validation split.

    Builds char<->id vocabularies from `text`, encodes the full text as a
    LongTensor, and serves random fixed-length windows plus their
    one-step-shifted targets.
    """

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device
        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample `batch_size` random (input, next-char target) windows."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        inputs = torch.stack([source[s : s + self.block_size] for s in starts])
        # Targets are the same windows shifted one character to the right.
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
69
+
70
+ # -----------------------------
71
+ # Chunked Sparse Autograd
72
+ # -----------------------------
73
+
74
class ChunkedMaskedLinear(torch.autograd.Function):
    """Dense-forward Linear whose backward only fills gradients for the
    selected output-row chunks.

    The forward pass is a plain ``F.linear``; all sparsity lives in
    ``backward``, where contiguous chunk slices (zero-copy strided views)
    keep every matmul dense and hardware friendly.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], active_chunks: torch.Tensor, chunk_size: int, sparse_dx: bool) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        cs = ctx.chunk_size

        in_flat = x.reshape(-1, x.shape[-1])
        out_grad = grad_y.reshape(-1, grad_y.shape[-1])

        # Weight/bias gradients start fully zeroed; only active chunks fill in.
        grad_w = torch.zeros_like(weight)
        grad_b = None
        if ctx.has_bias:
            grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)

        # Full dX is one dense matmul; sparse dX accumulates per chunk below.
        grad_in = torch.zeros_like(in_flat) if ctx.sparse_dx else out_grad @ weight

        for chunk in active_chunks.tolist():
            lo = chunk * cs
            hi = lo + cs

            # Contiguous slices are zero-copy views — no gather, no copies.
            g_out = out_grad[:, lo:hi]

            # Dense hardware matmul on the chunk slice.
            grad_w[lo:hi, :] = g_out.t() @ in_flat

            if grad_b is not None:
                grad_b[lo:hi] = g_out.sum(dim=0)

            if ctx.sparse_dx:
                grad_in += g_out @ weight[lo:hi, :]

        return grad_in.reshape(x.shape), grad_w, grad_b, None, None, None
119
+
120
class SparseLinear(nn.Linear):
    """nn.Linear that can route its backward pass through ChunkedMaskedLinear.

    The training loop toggles behavior per step via mutable attributes:
      sparse_enabled -- when False (default) behaves exactly like nn.Linear.
      sparse_dx      -- when True, the input gradient is also restricted to
                        the active chunks; otherwise dX stays dense.
      active_chunks  -- 1-D tensor of local chunk indices, or None (dense).
      chunk_size     -- rows per chunk; must match the masker's chunk size.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, chunk_size: int = 64):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_chunks: Optional[torch.Tensor] = None
        # Fix: previously forward() used getattr(self, 'chunk_size', 64) with a
        # hidden fallback; an explicit attribute (still assignable from outside,
        # default unchanged at 64) makes the effective chunk size visible.
        self.chunk_size = chunk_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dense fast path when sparsity is disabled or no chunks were chosen.
        if not self.sparse_enabled or self.active_chunks is None:
            return F.linear(x, self.weight, self.bias)
        return ChunkedMaskedLinear.apply(x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx)
131
+
132
+ # -----------------------------
133
+ # Mini GPT
134
+ # -----------------------------
135
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on SparseLinear projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask, broadcastable over (batch, head).
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)
        head_shape = (B, T, self.n_head, self.head_dim)
        q = q.view(*head_shape).transpose(1, 2)
        k = k.view(*head_shape).transpose(1, 2)
        v = v.view(*head_shape).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
160
+
161
class FeedForward(nn.Module):
    """Position-wise MLP: expand 4x, GELU, project back, then dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
170
+
171
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
183
+
184
class MiniGPT(nn.Module):
    """Small GPT: token + position embeddings, N pre-norm blocks, dense head."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)

        # Standard nn.Linear head: not chunk-restricted, so the vocab
        # projection always receives dense gradients.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        B, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
208
+
209
def get_sparse_linears(model):
    """Collect every SparseLinear submodule of `model`, in modules() order."""
    return [module for module in model.modules() if isinstance(module, SparseLinear)]
211
+
212
+ # -----------------------------
213
+ # Chunk Masker
214
+ # -----------------------------
215
class ChunkMasker:
    """Selects which weight-row chunks receive gradients on each step.

    Chunks are numbered globally across all SparseLinear modules; each
    module is handed the *local* indices of its active chunks. The
    predicted_magnitude policy ranks chunks by an EMA of the gradient norms
    observed on previously-active chunks.
    """

    def __init__(self, model: nn.Module, policy: Policy, active_fraction: float, chunk_size: int, device: str):
        self.policy = policy
        self.active_fraction = active_fraction
        self.chunk_size = chunk_size
        self.device = device

        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids = {}
        next_id = 0
        for lin in self.linears:
            assert lin.out_features % chunk_size == 0, f"out_features {lin.out_features} not divisible by chunk size {chunk_size}"
            count = lin.out_features // chunk_size
            self.module_to_chunk_ids[lin] = torch.arange(next_id, next_id + count, device=device)
            next_id += count

        self.n_chunks = next_id
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)

    def choose_active(self, step: int, warmup_steps: int):
        """Pick this step's active chunk set and push it into each module."""
        if step < warmup_steps:
            # Warmup: everything active so the EMA observes every chunk.
            self.active_chunks.fill_(True)
            for lin, ids in self.module_to_chunk_ids.items():
                lin.active_chunks = torch.arange(len(ids), device=self.device)
            return

        budget = max(1, int(self.active_fraction * self.n_chunks))
        self.active_chunks.fill_(False)

        if self.policy == "random":
            picks = torch.randperm(self.n_chunks, device=self.device)[:budget]
            self.active_chunks[picks] = True
        elif self.policy == "predicted_magnitude":
            # Tiny noise breaks ties between equal-mass chunks.
            noisy = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            self.active_chunks[torch.topk(noisy, k=budget).indices] = True

        # Translate the global mask back to per-module local chunk indices.
        for lin, ids in self.module_to_chunk_ids.items():
            selected = self.active_chunks[ids]
            lin.active_chunks = torch.arange(len(ids), device=self.device)[selected]

    @torch.no_grad()
    def update_predictor(self, mass_beta=0.95):
        """EMA-update predicted gradient mass for chunks seen this step."""
        observed_mass = torch.zeros_like(self.predicted_mass)
        for lin, ids in self.module_to_chunk_ids.items():
            if lin.weight.grad is None:
                continue
            # Per-chunk squared mass: [Out, In] -> [n_chunks, chunk_size, In].
            sq = lin.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if lin.bias is not None and lin.bias.grad is not None:
                sq += lin.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)
            observed_mass[ids] = torch.sqrt(sq + 1e-30)

        # Only chunks that were active this step contribute observations.
        seen = self.active_chunks
        self.predicted_mass[seen] = mass_beta * self.predicted_mass[seen] + (1.0 - mass_beta) * observed_mass[seen]
272
+
273
+ # -----------------------------
274
+ # Chunked Adam
275
+ # -----------------------------
276
class ChunkedAdam:
    """Adam variant that only updates moments and weights on active chunks.

    Non-sparse parameters (embeddings, layernorms, LM head) get ordinary
    dense Adam updates; SparseLinear parameters are updated chunk by chunk,
    so inactive rows keep stale weights AND stale optimizer state. No bias
    correction is applied (matches the dense path for a fair benchmark).
    """

    def __init__(self, model, lr=3e-4, chunk_size=64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        self.state = {}

        # Map each sparse parameter back to its owning module so step() can
        # look up that module's currently-active chunks.
        self.param_to_sparse_module = {}
        for mod in get_sparse_linears(model):
            if mod.weight is not None:
                self.param_to_sparse_module[mod.weight] = mod
            if mod.bias is not None:
                self.param_to_sparse_module[mod.bias] = mod

    def zero_grad(self):
        """Drop all gradients (sets .grad to None rather than zeroing)."""
        for param in self.model.parameters():
            param.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one (bias-uncorrected) Adam update to all gradded params."""
        for param in self.model.parameters():
            grad = param.grad
            if grad is None:
                continue
            if param not in self.state:
                self.state[param] = {"m": torch.zeros_like(param), "v": torch.zeros_like(param)}
            moments = self.state[param]
            m, v = moments["m"], moments["v"]

            owner = self.param_to_sparse_module.get(param)
            active = getattr(owner, "active_chunks", None) if owner else None

            if active is None:
                # Dense path: embeddings, layernorms, LM head, or baseline.
                m.mul_(0.9).add_(grad, alpha=0.1)
                v.mul_(0.999).addcmul_(grad, grad, value=0.001)
                param.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr)
                continue

            # Sparse path: touch ONLY the active chunks (local indices).
            for local in active.tolist():
                lo = local * self.chunk_size
                hi = (local + 1) * self.chunk_size

                g_chunk = grad[lo:hi]
                m_chunk = m[lo:hi]
                v_chunk = v[lo:hi]

                m_chunk.mul_(0.9).add_(g_chunk, alpha=0.1)
                v_chunk.mul_(0.999).addcmul_(g_chunk, g_chunk, value=0.001)

                param[lo:hi].sub_(m_chunk / (torch.sqrt(v_chunk) + 1e-8), alpha=self.lr)
327
+ # -----------------------------
328
+ # Training
329
+ # -----------------------------
330
def main():
    """Benchmark dense training against chunked-sparse backward variants.

    Trains the same MiniGPT on a synthetic corpus in three configurations
    (dense baseline, sparse dW with full dX, sparse dW with sparse dX) and
    prints wall-clock time per run/step plus one validation-batch loss.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=16)
    parser.add_argument("--n_embd", type=int, default=1024)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.05)
    parser.add_argument("--device", type=str, default="mps")
    # Synchronize the accelerator around the timed region for honest timings.
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    corpus = CharCorpus(make_synthetic_corpus(), args.block_size, args.device)

    # (masking policy, backward mode) pairs to benchmark.
    modes =[
        ("dense_baseline", "dense_baseline"),
        ("predicted_magnitude", "sparse_dW_full_dX"),
        ("predicted_magnitude", "sparse_dW_sparse_dX")
    ]

    print(f"\nModel: {args.n_layer} layers, {args.n_embd} d_model, {args.chunk_size} chunk_size")
    print(f"Batch: {args.batch_size}, Block: {args.block_size}. Active Fraction: {args.active_fraction}\n")
    print(f"{'Run':>20s} | {'Time (s)':>10s} | {'Step (ms)':>10s} | {'Val Loss':>8s}")
    print("-" * 55)

    for policy, bwd_mode in modes:
        # Re-seed per run so every configuration starts from identical weights.
        set_seed(42)
        model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.0).to(args.device)

        for m in get_sparse_linears(model):
            m.chunk_size = args.chunk_size

        masker = ChunkMasker(model, policy, args.active_fraction, args.chunk_size, args.device) if policy != "dense_baseline" else None
        opt = ChunkedAdam(model, chunk_size=args.chunk_size)

        if args.benchmark_sync: sync_device(args.device)
        t0 = time.perf_counter()

        for step in range(args.steps):
            # Fixed per-step generator -> identical batch order across runs.
            x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))

            if masker:
                masker.choose_active(step, warmup_steps=5)
                for m in get_sparse_linears(model):
                    m.sparse_enabled = True
                    m.sparse_dx = (bwd_mode == "sparse_dW_sparse_dX")
            else:
                # Dense baseline: disable sparsity entirely.
                for m in get_sparse_linears(model):
                    m.sparse_enabled = False
                    m.active_chunks = None

            opt.zero_grad()
            _, loss = model(x, y)
            loss.backward()

            # Feed the observed chunk gradient norms back into the predictor
            # before the optimizer consumes (and the next step clears) them.
            if masker:
                masker.update_predictor()

            opt.step()

        if args.benchmark_sync: sync_device(args.device)
        t_elapsed = time.perf_counter() - t0

        # Eval loss
        model.eval()
        with torch.no_grad():
            # Same fixed generator for all runs -> comparable val batches.
            x, y = corpus.get_batch("val", args.batch_size, generator=make_cpu_generator(999))
            _, val_loss = model(x, y)

        # Format the mode strictly for the printout width
        bwd_str = bwd_mode if bwd_mode == "dense_baseline" else ("sparse_full_dX" if "full_dX" in bwd_mode else "sparse_sparse_dX")
        print(f"{bwd_str:>20s} | {t_elapsed:10.2f} | {1000*t_elapsed/args.steps:10.2f} | {val_loss.item():8.4f}")

if __name__ == "__main__":
    main()
experiments/sparse_linear_v11_gather_vs_metal/sparse_transformer_v13.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v14: The Final Architecture.
3
+
4
+ Combines the Hardware-Sympathetic Chunked Sparse backward pass with Cosine Annealing,
5
+ but restores the Chunked Optimizer to prevent Dense memory-bandwidth bottlenecks.
6
+ Benchmarks are isolated to the steady-state phase (after annealing) for accurate timings.
7
+
8
+ Run:
9
+ python3 sparse_transformer_v14.py --device mps --benchmark_sync --n_embd 1024
10
+ """
11
+
12
+ import argparse
13
+ import math
14
+ import random
15
+ import time
16
+ from typing import Dict, List, Literal, Optional, Tuple
17
+
18
+ import torch
19
+
20
+ torch.set_num_threads(1)
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
def sync_device(device: str) -> None:
    """Block until all queued kernels on *device* finish (no-op for CPU)."""
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
30
+ Policy = Literal["predicted_magnitude", "oracle_current", "random"]
31
+ BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
32
+
33
def set_seed(seed: int) -> None:
    """Seed both torch's and Python's ``random`` RNGs for reproducible runs."""
    torch.manual_seed(seed)
    random.seed(seed)
36
+
37
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU ``torch.Generator`` already seeded with *seed*."""
    # manual_seed() returns the generator itself, so this can be one expression.
    return torch.Generator(device="cpu").manual_seed(seed)
41
+
42
+ # -----------------------------
43
+ # Data
44
+ # -----------------------------
45
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic corpus of *n_sentences* random-word sentences."""
    rng = random.Random(seed)
    vocab = ["ada", "turing", "grace", "lovelace", "gradients", "tokens", "circuits", "features", "boldly", "strangely"]
    sentences = []
    for _ in range(n_sentences):
        # Keep the RNG call order identical: length first, then word choices.
        length = rng.randint(4, 10)
        sentences.append(" ".join(rng.choices(vocab, k=length)) + ".")
    return "\n".join(sentences)
49
+
50
class CharCorpus:
    """Character-level corpus with a deterministic 90/10 train/val split."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {c: i for i, c in enumerate(vocab)}
        self.itos = {i: c for c, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device
        encoded = torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
        split = int(0.9 * len(encoded))
        self.train_data = encoded[:split]
        self.val_data = encoded[split:]

    def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a random batch of (input window, next-char target window)."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        x = torch.stack([source[s : s + self.block_size] for s in starts])
        y = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return x.to(self.device), y.to(self.device)
68
+
69
+ # -----------------------------
70
+ # Chunked Sparse Autograd
71
+ # -----------------------------
72
+
73
class ChunkedMaskedLinear(torch.autograd.Function):
    """Linear op whose backward only materialises gradients for active chunks.

    The forward pass is a plain dense ``F.linear``; sparsity applies only in
    backward, where dW/db (and optionally dX) are computed just for the rows
    of W listed in ``active_chunks``. Inactive chunks receive zero gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], active_chunks: torch.Tensor, chunk_size: int, sparse_dx: bool) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        cs = ctx.chunk_size

        x2d = x.reshape(-1, x.shape[-1])
        gy2d = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight)
        grad_b = None
        if ctx.has_bias:
            grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)

        # Full dX is one dense matmul; sparse dX is accumulated chunk by chunk.
        grad_x2d = torch.zeros_like(x2d) if ctx.sparse_dx else gy2d @ weight

        # Contiguous row slices feed directly into dense hardware matmuls.
        for chunk in active_chunks.tolist():
            lo = chunk * cs
            hi = lo + cs
            gy_c = gy2d[:, lo:hi]
            grad_w[lo:hi, :] = gy_c.t() @ x2d
            if grad_b is not None:
                grad_b[lo:hi] = gy_c.sum(dim=0)
            if ctx.sparse_dx:
                grad_x2d += gy_c @ weight[lo:hi, :]

        return grad_x2d.reshape(x.shape), grad_w, grad_b, None, None, None
115
+
116
class SparseLinear(nn.Linear):
    """``nn.Linear`` that can route through the chunked-sparse backward pass.

    With ``sparse_enabled`` off (the default) it behaves exactly like a plain
    linear layer; once a masker assigns ``active_chunks`` and the flag is set,
    the forward is dispatched through ``ChunkedMaskedLinear``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False   # toggled each step by the training loop
        self.sparse_dx = False        # also sparsify dX (vs. always-dense dX)
        self.active_chunks: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        use_sparse = self.sparse_enabled and self.active_chunks is not None
        if not use_sparse:
            return F.linear(x, self.weight, self.bias)
        chunk = getattr(self, 'chunk_size', 64)
        return ChunkedMaskedLinear.apply(x, self.weight, self.bias, self.active_chunks, chunk, self.sparse_dx)
127
+
128
+ # -----------------------------
129
+ # Mini GPT
130
+ # -----------------------------
131
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built from sparse-capable projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)  # fused q, k, v projection
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(t) for t in self.c_attn(x).split(C, dim=2))
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
156
+
157
class FeedForward(nn.Module):
    """Standard 4x-expansion GELU MLP built from sparse-capable linears."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
166
+
167
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
179
+
180
class MiniGPT(nn.Module):
    """Small GPT: token+position embeddings, N blocks, final norm, dense LM head."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*(Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)))
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are given."""
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
201
+
202
def get_sparse_linears(model):
    """Collect every SparseLinear submodule of *model* in traversal order."""
    found = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
204
+
205
+ # -----------------------------
206
+ # Chunk Masker with Annealing
207
+ # -----------------------------
208
class ChunkMasker:
    """Chooses which output chunks are trained each step, with cosine annealing.

    All SparseLinear layers share one global chunk index space. An EMA of
    observed per-chunk gradient mass (``predicted_mass``) drives the
    "predicted_magnitude" policy; only chunks that were active (and thus
    actually received gradients) get their prediction refreshed.
    """

    def __init__(self, model: nn.Module, policy: Policy, target_fraction: float, chunk_size: int, device: str):
        self.policy = policy
        self.target_fraction = target_fraction
        self.chunk_size = chunk_size
        self.device = device

        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids = {}
        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, f"out_features {m.out_features} not divisible by chunk size {chunk_size}"
            n_chunks = m.out_features // chunk_size
            self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device)
            offset += n_chunks

        self.n_chunks = offset
        # EMA of per-chunk gradient magnitude; scores for the topk selection.
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # Global boolean mask of the currently active chunks.
        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)

    def choose_active(self, step: int, warmup_steps: int, anneal_steps: int) -> None:
        """Pick this step's active chunk set and push local ids to each layer.

        Schedule: dense warmup -> cosine anneal -> steady ``target_fraction``.

        Raises:
            ValueError: for a policy with no selection rule (previously such
                policies silently selected ZERO chunks, freezing all layers).
        """
        if step < warmup_steps:
            current_fraction = 1.0
        elif step < warmup_steps + anneal_steps:
            progress = (step - warmup_steps) / anneal_steps
            cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
            current_fraction = self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
        else:
            current_fraction = self.target_fraction

        if current_fraction >= 0.999:
            # Effectively dense: activate everything without a top-k pass.
            self.active_chunks.fill_(True)
            for m, ids in self.module_to_chunk_ids.items():
                m.active_chunks = torch.arange(len(ids), device=self.device)
            return

        k = max(1, int(current_fraction * self.n_chunks))
        self.active_chunks.fill_(False)

        if self.policy == "random":
            self.active_chunks[torch.randperm(self.n_chunks, device=self.device)[:k]] = True
        elif self.policy == "predicted_magnitude":
            # Tiny noise breaks ties between equal (e.g. still-zero) scores.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            self.active_chunks[torch.topk(scores, k=k).indices] = True
        else:
            # BUG FIX: unknown policies (e.g. "oracle_current", declared in the
            # Policy Literal but never implemented) used to fall through here
            # with zero active chunks, silently freezing every sparse layer.
            raise ValueError(f"unsupported chunk-selection policy: {self.policy!r}")

        # Translate the global mask into per-module local chunk indices.
        for m, ids in self.module_to_chunk_ids.items():
            global_active = self.active_chunks[ids]
            local_ids = torch.arange(len(ids), device=self.device)
            m.active_chunks = local_ids[global_active]

    @torch.no_grad()
    def update_predictor(self, mass_beta: float = 0.95) -> None:
        """EMA-update the predicted gradient mass for chunks observed this step."""
        current_mass = torch.zeros_like(self.predicted_mass)
        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None:
                continue
            w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)
            current_mass[ids] = torch.sqrt(w_sq + 1e-30)

        # Only chunks that actually received gradients get refreshed scores.
        observed = self.active_chunks
        self.predicted_mass[observed] = mass_beta * self.predicted_mass[observed] + (1.0 - mass_beta) * current_mass[observed]
271
+
272
+
273
+ # -----------------------------
274
+ # Chunked Adam (Restored)
275
+ # -----------------------------
276
class ChunkedAdam:
    """Adam-style optimizer that updates only the active chunks of sparse params.

    Parameters not owned by a SparseLinear — or whose module has no
    ``active_chunks`` mask — get a dense update. NOTE: deliberately omits
    Adam's bias correction, matching the benchmark baseline.
    """

    def __init__(self, model, lr=3e-4, chunk_size=64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        self.state = {}

        # Map each sparse-layer parameter back to its owning module so step()
        # can look up that module's current active-chunk mask.
        self.param_to_sparse_module = {}
        for m in get_sparse_linears(model):
            if m.weight is not None:
                self.param_to_sparse_module[m.weight] = m
            if m.bias is not None:
                self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}

            exp_avg = self.state[p]["m"]
            exp_avg_sq = self.state[p]["v"]

            owner = self.param_to_sparse_module.get(p)
            active_chunks = getattr(owner, 'active_chunks', None) if owner else None

            if active_chunks is None:
                # Dense update: embeddings, layernorms, LM head, or baseline mode.
                exp_avg.mul_(0.9).add_(p.grad, alpha=0.1)
                exp_avg_sq.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                p.sub_(exp_avg / (torch.sqrt(exp_avg_sq) + 1e-8), alpha=self.lr)
                continue

            # Sparse update: touch only the rows belonging to active chunks.
            for local_c in active_chunks.tolist():
                lo = local_c * self.chunk_size
                hi = lo + self.chunk_size

                g = p.grad[lo:hi]
                m_c = exp_avg[lo:hi]
                v_c = exp_avg_sq[lo:hi]

                m_c.mul_(0.9).add_(g, alpha=0.1)
                v_c.mul_(0.999).addcmul_(g, g, value=0.001)
                p[lo:hi].sub_(m_c / (torch.sqrt(v_c) + 1e-8), alpha=self.lr)
326
+
327
+ # -----------------------------
328
+ # Training
329
+ # -----------------------------
330
def _parse_args():
    """CLI options for the v14 dense-vs-sparse benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=16)
    parser.add_argument("--n_embd", type=int, default=1024)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.05)
    parser.add_argument("--warmup_steps", type=int, default=10)
    parser.add_argument("--anneal_steps", type=int, default=150)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    return parser.parse_args()

def main():
    """Train the same model under dense and chunked-sparse modes and report timings."""
    args = _parse_args()
    corpus = CharCorpus(make_synthetic_corpus(), args.block_size, args.device)

    modes = [
        ("dense_baseline", "dense_baseline"),
        ("predicted_magnitude", "sparse_dW_full_dX"),
        ("predicted_magnitude", "sparse_dW_sparse_dX"),
    ]

    print(f"\nModel: {args.n_layer} layers, {args.n_embd} d_model, {args.chunk_size} chunk_size")
    print(f"Batch: {args.batch_size}, Block: {args.block_size}. Target Active Fraction: {args.active_fraction}")
    print(f"Annealing: {args.warmup_steps} warmup steps, {args.anneal_steps} anneal steps.\n")
    print(f"{'Run':>20s} | {'Time (s)':>10s} | {'Step (ms)':>10s} | {'Val Loss':>8s}")
    print("-" * 55)

    for policy, bwd_mode in modes:
        set_seed(42)
        model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.0).to(args.device)
        for layer in get_sparse_linears(model):
            layer.chunk_size = args.chunk_size

        masker = None
        if policy != "dense_baseline":
            masker = ChunkMasker(model, policy, args.active_fraction, args.chunk_size, args.device)

        # Chunked optimizer: skips Adam state updates for inactive chunks.
        opt = ChunkedAdam(model, chunk_size=args.chunk_size)

        if args.benchmark_sync:
            sync_device(args.device)

        # The timer restarts once annealing finishes, so the reported numbers
        # reflect steady-state sparse speed rather than the dense warmup.
        t0 = time.perf_counter()
        measured_steps = args.steps

        for step in range(args.steps):
            if step == args.warmup_steps + args.anneal_steps:
                if args.benchmark_sync:
                    sync_device(args.device)
                t0 = time.perf_counter()
                measured_steps = args.steps - step

            x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))

            if masker:
                masker.choose_active(step, warmup_steps=args.warmup_steps, anneal_steps=args.anneal_steps)
                for layer in get_sparse_linears(model):
                    layer.sparse_enabled = True
                    layer.sparse_dx = (bwd_mode == "sparse_dW_sparse_dX")
            else:
                for layer in get_sparse_linears(model):
                    layer.sparse_enabled = False
                    layer.active_chunks = None

            opt.zero_grad()
            _, loss = model(x, y)
            loss.backward()

            if masker:
                masker.update_predictor()

            opt.step()

        if args.benchmark_sync:
            sync_device(args.device)
        t_elapsed = time.perf_counter() - t0

        # Validation loss on a fixed batch.
        model.eval()
        with torch.no_grad():
            x, y = corpus.get_batch("val", args.batch_size, generator=make_cpu_generator(999))
            _, val_loss = model(x, y)

        bwd_str = bwd_mode if bwd_mode == "dense_baseline" else ("sparse_full_dX" if "full_dX" in bwd_mode else "sparse_sparse_dX")
        print(f"{bwd_str:>20s} | {t_elapsed:10.2f} | {1000*t_elapsed/max(1, measured_steps):10.2f} | {val_loss.item():8.4f}")

if __name__ == "__main__":
    main()
experiments/sparse_linear_v11_gather_vs_metal/tiny.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer: Real-World Benchmark on Tiny Shakespeare using GPT-2 BPE.
3
+
4
+ This script scales the architecture to a 6-layer, 512-dim GPT and trains on
5
+ real natural language. It applies our Hardware-Sympathetic Chunked Sparse
6
+ backward pass, Cosine Annealing, and Chunked Adam optimizer.
7
+
8
+ Run:
9
+ python3 sparse_transformer_shakespeare.py --device mps --benchmark_sync
10
+ """
11
+
12
+ import argparse
13
+ import math
14
+ import os
15
+ import random
16
+ import time
17
+ import urllib.request
18
+ from typing import Dict, List, Literal, Optional, Tuple
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ try:
25
+ import tiktoken
26
+ except ImportError:
27
+ raise ImportError("Please install tiktoken: pip install tiktoken")
28
+
29
+ torch.set_num_threads(1)
30
+
31
def sync_device(device: str) -> None:
    """Block until all queued kernels on *device* finish (no-op for CPU)."""
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
36
+
37
+ Policy = Literal["predicted_magnitude", "oracle_current", "random"]
38
+ BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
39
+
40
def set_seed(seed: int) -> None:
    """Seed both torch's and Python's ``random`` RNGs for reproducible runs."""
    torch.manual_seed(seed)
    random.seed(seed)
43
+
44
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU ``torch.Generator`` already seeded with *seed*."""
    # manual_seed() returns the generator itself, so this can be one expression.
    return torch.Generator(device="cpu").manual_seed(seed)
48
+
49
+ # -----------------------------
50
+ # Real-World Data Pipeline
51
+ # -----------------------------
52
class ShakespeareCorpus:
    """Tiny Shakespeare tokenized with GPT-2 BPE; deterministic 90/10 split.

    Downloads the raw text to ``input.txt`` on first use, then tokenizes the
    whole corpus with tiktoken's "gpt2" encoding.
    """

    def __init__(self, block_size: int, device: str):
        self.block_size = block_size
        self.device = device

        # Fetch the raw text once and cache it next to the script.
        data_path = "input.txt"
        if not os.path.exists(data_path):
            print("Downloading Tiny Shakespeare...")
            url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
            urllib.request.urlretrieve(url, data_path)

        print("Tokenizing data...")
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.read()

        enc = tiktoken.get_encoding("gpt2")
        data = torch.tensor(enc.encode(text), dtype=torch.long)
        self.vocab_size = enc.n_vocab

        split_idx = int(0.9 * len(data))
        self.train_data = data[:split_idx]
        self.val_data = data[split_idx:]

        print(f"Dataset loaded. Vocab size: {self.vocab_size:,}. Train tokens: {len(self.train_data):,}")

    def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a random batch of (input window, next-token target window)."""
        source = self.train_data if split == "train" else self.val_data
        ix = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        x = torch.stack([source[i : i + self.block_size] for i in ix])
        y = torch.stack([source[i + 1 : i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)
87
+
88
+ # -----------------------------
89
+ # Chunked Sparse Autograd
90
+ # -----------------------------
91
class ChunkedMaskedLinear(torch.autograd.Function):
    """Linear op whose backward only materialises gradients for active chunks.

    The forward pass is a plain dense ``F.linear``; sparsity applies only in
    backward, where dW/db (and optionally dX) are computed just for the rows
    of W listed in ``active_chunks``. Inactive chunks receive zero gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], active_chunks: torch.Tensor, chunk_size: int, sparse_dx: bool) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        cs = ctx.chunk_size

        x2d = x.reshape(-1, x.shape[-1])
        gy2d = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight)
        grad_b = None
        if ctx.has_bias:
            grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)

        # Full dX is one dense matmul; sparse dX is accumulated chunk by chunk.
        grad_x2d = torch.zeros_like(x2d) if ctx.sparse_dx else gy2d @ weight

        # Contiguous row slices feed directly into dense hardware matmuls.
        for chunk in active_chunks.tolist():
            lo = chunk * cs
            hi = lo + cs
            gy_c = gy2d[:, lo:hi]
            grad_w[lo:hi, :] = gy_c.t() @ x2d
            if grad_b is not None:
                grad_b[lo:hi] = gy_c.sum(dim=0)
            if ctx.sparse_dx:
                grad_x2d += gy_c @ weight[lo:hi, :]

        return grad_x2d.reshape(x.shape), grad_w, grad_b, None, None, None
133
+
134
class SparseLinear(nn.Linear):
    """``nn.Linear`` that can route through the chunked-sparse backward pass.

    With ``sparse_enabled`` off (the default) it behaves exactly like a plain
    linear layer; once a masker assigns ``active_chunks`` and the flag is set,
    the forward is dispatched through ``ChunkedMaskedLinear``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False   # toggled each step by the training loop
        self.sparse_dx = False        # also sparsify dX (vs. always-dense dX)
        self.active_chunks: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        use_sparse = self.sparse_enabled and self.active_chunks is not None
        if not use_sparse:
            return F.linear(x, self.weight, self.bias)
        chunk = getattr(self, 'chunk_size', 64)
        return ChunkedMaskedLinear.apply(x, self.weight, self.bias, self.active_chunks, chunk, self.sparse_dx)
145
+
146
+ # -----------------------------
147
+ # GPT Architecture
148
+ # -----------------------------
149
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built from sparse-capable projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)  # fused q, k, v projection
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(t) for t in self.c_attn(x).split(C, dim=2))
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
174
+
175
class FeedForward(nn.Module):
    """Standard 4x-expansion GELU MLP built from sparse-capable linears."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
184
+
185
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
197
+
198
class GPT(nn.Module):
    """GPT for Tiny Shakespeare: embeddings, N blocks, final norm, LM head.

    The LM head stays a plain dense ``nn.Linear`` because cross-entropy needs
    the full output distribution over the vocabulary.
    """

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*(Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)))
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are given."""
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
220
+
221
def get_sparse_linears(model):
    """Collect every SparseLinear submodule of *model* in traversal order."""
    found = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
223
+
224
+ # -----------------------------
225
+ # Chunk Masker with Annealing
226
+ # -----------------------------
227
class ChunkMasker:
    """Chooses which output chunks are trained each step, with cosine annealing.

    All SparseLinear layers share one global chunk index space. An EMA of
    observed per-chunk gradient mass (``predicted_mass``) drives the
    "predicted_magnitude" policy; only chunks that were active (and thus
    actually received gradients) get their prediction refreshed.
    """

    def __init__(self, model: nn.Module, policy: Policy, target_fraction: float, chunk_size: int, device: str):
        self.policy = policy
        self.target_fraction = target_fraction
        self.chunk_size = chunk_size
        self.device = device

        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids = {}
        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, f"out_features {m.out_features} not divisible by chunk size {chunk_size}"
            n_chunks = m.out_features // chunk_size
            self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device)
            offset += n_chunks

        self.n_chunks = offset
        # EMA of per-chunk gradient magnitude; scores for the topk selection.
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # Global boolean mask of the currently active chunks.
        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)

    def choose_active(self, step: int, warmup_steps: int, anneal_steps: int) -> None:
        """Pick this step's active chunk set and push local ids to each layer.

        Schedule: dense warmup -> cosine anneal -> steady ``target_fraction``.

        Raises:
            ValueError: for a policy with no selection rule (previously such
                policies silently selected ZERO chunks, freezing all layers).
        """
        if step < warmup_steps:
            current_fraction = 1.0
        elif step < warmup_steps + anneal_steps:
            progress = (step - warmup_steps) / anneal_steps
            cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
            current_fraction = self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
        else:
            current_fraction = self.target_fraction

        if current_fraction >= 0.999:
            # Effectively dense: activate everything without a top-k pass.
            self.active_chunks.fill_(True)
            for m, ids in self.module_to_chunk_ids.items():
                m.active_chunks = torch.arange(len(ids), device=self.device)
            return

        k = max(1, int(current_fraction * self.n_chunks))
        self.active_chunks.fill_(False)

        if self.policy == "random":
            self.active_chunks[torch.randperm(self.n_chunks, device=self.device)[:k]] = True
        elif self.policy == "predicted_magnitude":
            # Tiny noise breaks ties between equal (e.g. still-zero) scores.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            self.active_chunks[torch.topk(scores, k=k).indices] = True
        else:
            # BUG FIX: unknown policies (e.g. "oracle_current", declared in the
            # Policy Literal but never implemented) used to fall through here
            # with zero active chunks, silently freezing every sparse layer.
            raise ValueError(f"unsupported chunk-selection policy: {self.policy!r}")

        # Translate the global mask into per-module local chunk indices.
        for m, ids in self.module_to_chunk_ids.items():
            global_active = self.active_chunks[ids]
            local_ids = torch.arange(len(ids), device=self.device)
            m.active_chunks = local_ids[global_active]

    @torch.no_grad()
    def update_predictor(self, mass_beta: float = 0.95) -> None:
        """EMA-update the predicted gradient mass for chunks observed this step."""
        current_mass = torch.zeros_like(self.predicted_mass)
        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None:
                continue
            w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)
            current_mass[ids] = torch.sqrt(w_sq + 1e-30)

        # Only chunks that actually received gradients get refreshed scores.
        observed = self.active_chunks
        self.predicted_mass[observed] = mass_beta * self.predicted_mass[observed] + (1.0 - mass_beta) * current_mass[observed]
289
+
290
+ # -----------------------------
291
+ # Chunked Adam
292
+ # -----------------------------
293
class ChunkedAdam:
    """Adam-style optimizer that updates only the active chunks of sparse params.

    Parameters not owned by a SparseLinear — or whose module has no
    ``active_chunks`` mask — get a dense update. NOTE: deliberately omits
    Adam's bias correction, matching the benchmark baseline.
    """

    def __init__(self, model, lr=5e-4, chunk_size=64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        self.state = {}

        # Map each sparse-layer parameter back to its owning module so step()
        # can look up that module's current active-chunk mask.
        self.param_to_sparse_module = {}
        for m in get_sparse_linears(model):
            if m.weight is not None:
                self.param_to_sparse_module[m.weight] = m
            if m.bias is not None:
                self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}

            exp_avg = self.state[p]["m"]
            exp_avg_sq = self.state[p]["v"]

            owner = self.param_to_sparse_module.get(p)
            active_chunks = getattr(owner, 'active_chunks', None) if owner else None

            if active_chunks is None:
                # Dense update: embeddings, layernorms, LM head, or baseline mode.
                exp_avg.mul_(0.9).add_(p.grad, alpha=0.1)
                exp_avg_sq.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                p.sub_(exp_avg / (torch.sqrt(exp_avg_sq) + 1e-8), alpha=self.lr)
                continue

            # Sparse update: touch only the rows belonging to active chunks.
            for local_c in active_chunks.tolist():
                lo = local_c * self.chunk_size
                hi = lo + self.chunk_size

                g = p.grad[lo:hi]
                m_c = exp_avg[lo:hi]
                v_c = exp_avg_sq[lo:hi]

                m_c.mul_(0.9).add_(g, alpha=0.1)
                v_c.mul_(0.999).addcmul_(g, g, value=0.001)
                p[lo:hi].sub_(m_c / (torch.sqrt(v_c) + 1e-8), alpha=self.lr)
342
+
343
+ # -----------------------------
344
+ # Training
345
+ # -----------------------------
346
def main():
    """Benchmark dense vs. sparse backward modes on the Shakespeare corpus.

    For each (policy, backward-mode) pair a fresh model is trained for
    ``--steps`` steps; timing restarts after warmup+anneal so the reported
    ms/step reflects steady-state sparse training. Prints a summary table of
    wall time, per-step time, and final validation loss.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=6)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=50)
    parser.add_argument("--anneal_steps", type=int, default=200)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    corpus = ShakespeareCorpus(args.block_size, args.device)

    modes = [
        ("dense_baseline", "dense_baseline"),
        ("predicted_magnitude", "sparse_dW_full_dX"),
        ("predicted_magnitude", "sparse_dW_sparse_dX"),
    ]

    print(f"\nModel: {args.n_layer} layers, {args.n_embd} d_model, {args.chunk_size} chunk_size")
    print(f"Batch: {args.batch_size}, Block: {args.block_size}. Target Active: {args.active_fraction*100}%")
    print(f"Annealing: {args.warmup_steps} warmup steps, {args.anneal_steps} anneal steps.\n")
    print(f"{'Run':>20s} | {'Time (s)':>10s} | {'Step (ms)':>10s} | {'Val Loss':>8s}")
    print("-" * 55)

    for policy, bwd_mode in modes:
        set_seed(42)
        model = GPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.1).to(args.device)

        for m in get_sparse_linears(model):
            m.chunk_size = args.chunk_size

        masker = ChunkMasker(model, policy, args.active_fraction, args.chunk_size, args.device) if policy != "dense_baseline" else None
        opt = ChunkedAdam(model, lr=5e-4, chunk_size=args.chunk_size)

        if args.benchmark_sync: sync_device(args.device)

        t0 = time.perf_counter()
        measured_steps = args.steps

        for step in range(args.steps):
            # Restart the clock once annealing finishes so timing reflects
            # the steady-state sparse regime only.
            if step == args.warmup_steps + args.anneal_steps:
                if args.benchmark_sync: sync_device(args.device)
                t0 = time.perf_counter()
                measured_steps = args.steps - step

            x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))

            if masker:
                masker.choose_active(step, warmup_steps=args.warmup_steps, anneal_steps=args.anneal_steps)
                for m in get_sparse_linears(model):
                    m.sparse_enabled = True
                    m.sparse_dx = (bwd_mode == "sparse_dW_sparse_dX")
            else:
                for m in get_sparse_linears(model):
                    m.sparse_enabled = False
                    m.active_chunks = None

            opt.zero_grad()
            _, loss = model(x, y)
            loss.backward()

            if masker:
                masker.update_predictor()

            opt.step()

            # Print progress every 200 steps (carriage return keeps one line).
            if step % 200 == 0:
                print(f" [Progress] {bwd_mode} step {step}/{args.steps} | Loss: {loss.item():.4f}", end="\r")

        if args.benchmark_sync: sync_device(args.device)
        t_elapsed = time.perf_counter() - t0

        # Eval loss.
        # Fix: this block was previously duplicated and nested inside itself
        # (model.eval()/no_grad appeared twice); one pass is sufficient.
        model.eval()
        with torch.no_grad():
            val_x, val_y = corpus.get_batch("val", args.batch_size, generator=make_cpu_generator(999))
            _, val_loss = model(val_x, val_y)

        # Clear the progress line
        print(" " * 60, end="\r")

        bwd_str = bwd_mode if bwd_mode == "dense_baseline" else ("sparse_full_dX" if "full_dX" in bwd_mode else "sparse_sparse_dX")
        print(f"{bwd_str:>20s} | {t_elapsed:10.2f} | {1000*t_elapsed/max(1, measured_steps):10.2f} | {val_loss.item():8.4f}")
439
+
440
# Script entry point: run the benchmark when executed directly.
if __name__ == "__main__":
    main()
experiments/sparse_transformer_v15_inactive_prediction.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v15: Inactive-Update Prediction Diagnostics.
3
+
4
+ Tests two simple ideas:
5
+
6
+ 1. Correlated-neighbor prediction:
7
+ Use active chunks as sensors. For each inactive chunk, find historically
8
+ correlated active chunks and predict its update magnitude from them.
9
+
10
+ 2. Graph / boundary interpolation:
11
+ Treat chunks as nodes in a learned similarity graph. Active chunks are
12
+ boundary values. Inactive chunk magnitudes are filled in by diffusion.
13
+
14
+ This is intentionally a diagnostic script, not a speed benchmark.
15
+ It computes dense gradients every step so we can measure whether inactive
16
+ updates are predictable.
17
+
18
+ Run:
19
+ python3 sparse_transformer_v15_inactive_prediction.py --device mps --benchmark_sync
20
+
21
+ Good first runs:
22
+ python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 512
23
+ python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 1024
24
+ """
25
+
26
+ import argparse
27
+ import math
28
+ import random
29
+ import time
30
+ from typing import Dict, List, Literal, Optional, Tuple
31
+
32
+ import torch
33
+
34
+ torch.set_num_threads(1)
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+
38
+
39
+ Policy = Literal["predicted_magnitude", "random"]
40
+
41
+
42
def sync_device(device: str) -> None:
    """Block until all queued kernels on *device* have finished.

    No-op for CPU or for accelerator names that are not actually available.
    Used so wall-clock timings are meaningful on asynchronous backends.
    """
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps") and torch.backends.mps.is_available():
        # hasattr alone is insufficient: torch.mps exists on builds without a
        # usable MPS device, where synchronize() would raise at runtime.
        torch.mps.synchronize()
47
+
48
+
49
def set_seed(seed: int) -> None:
    """Seed both Python's and torch's global RNGs for reproducible runs."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
52
+
53
+
54
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a freshly seeded CPU ``torch.Generator``."""
    generator = torch.Generator(device="cpu")
    # manual_seed returns the generator itself, so this is a single expression.
    return generator.manual_seed(seed)
58
+
59
+
60
+ # -----------------------------
61
+ # Data
62
+ # -----------------------------
63
+
64
+ def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
65
+ rng = random.Random(seed)
66
+ words = [
67
+ "ada", "turing", "grace", "lovelace", "gradients",
68
+ "tokens", "circuits", "features", "boldly", "strangely",
69
+ "matrix", "attention", "kernel", "entropy", "signal",
70
+ ]
71
+ return "\n".join(
72
+ " ".join(rng.choices(words, k=rng.randint(4, 10))) + "."
73
+ for _ in range(n_sentences)
74
+ )
75
+
76
+
77
class CharCorpus:
    """Character-level corpus with a 90/10 train/validation split."""

    def __init__(self, text: str, block_size: int, device: str):
        chars = sorted(set(text))
        # char -> id and id -> char lookup tables over the observed alphabet.
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.device = device

        data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        self.train_data = data[: int(0.9 * len(data))]
        self.val_data = data[int(0.9 * len(data)) :]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample (input, next-char target) windows from *split*.

        Offsets are drawn on CPU with the supplied generator so batches are
        reproducible; results are moved to ``self.device`` at the end.
        """
        data = self.train_data if split == "train" else self.val_data
        ix = torch.randint(len(data) - self.block_size - 1, (batch_size,), generator=generator)
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        # Targets are the same windows shifted right by one character.
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)
101
+
102
+
103
+ # -----------------------------
104
+ # Model
105
+ # -----------------------------
106
+
107
class SparseLinear(nn.Linear):
    """Name retained for compatibility with earlier experiments.

    In this diagnostic script, backward is dense. We only use chunk masks
    analytically after gradients are computed. The subclass adds no behavior
    over ``nn.Linear``; it exists so chunk bookkeeping can find these layers
    by type.
    """
113
+
114
+
115
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an explicit (non-fused) softmax."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused q,k,v projection followed by the output projection.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular mask, shaped to broadcast over (batch, head, T, T).
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)

        # Reshape to (B, n_head, T, head_dim) for per-head attention.
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores with causal masking before softmax.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        # Merge heads back into the embedding dimension.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
146
+
147
+
148
class FeedForward(nn.Module):
    """GPT-style MLP: expand 4x, GELU, project back, then dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
157
+
158
+
159
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
171
+
172
+
173
class MiniGPT(nn.Module):
    """Small GPT: token + learned positional embeddings, N blocks, LM head."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        # Final LM head is a plain nn.Linear, so it is excluded from the
        # SparseLinear chunk bookkeeping.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when *targets* is omitted."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
205
+
206
+
207
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every SparseLinear submodule of *model*, in traversal order."""
    found: List[SparseLinear] = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
209
+
210
+
211
+ # -----------------------------
212
+ # Chunk geometry and diagnostics
213
+ # -----------------------------
214
+
215
class ChunkMap:
    """Global indexing and statistics over output chunks of all SparseLinears.

    Assigns every ``chunk_size`` consecutive output rows of every sparse
    linear a global chunk id, and maintains per-chunk predictors: a
    gradient-mass EMA, a unit-direction EMA, and a bounded mass history used
    to build similarity graphs.
    """

    def __init__(self, model: nn.Module, chunk_size: int, device: str):
        self.model = model
        self.chunk_size = chunk_size
        self.device = device
        self.linears = get_sparse_linears(model)

        # module -> its global chunk ids; global id -> (module, local index).
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []

        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            for local_c in range(n_chunks):
                self.chunk_to_module_local.append((m, local_c))
            offset += n_chunks

        self.n_chunks = offset
        # EMA of observed gradient mass per chunk (0 means "never observed").
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # Per-chunk unit-direction EMA; None until first observation.
        self.direction_ema: List[Optional[torch.Tensor]] = [None for _ in range(self.n_chunks)]

        # Histories for correlation and graph similarities.
        self.mass_history: List[torch.Tensor] = []

    def choose_active(
        self,
        step: int,
        warmup_steps: int,
        active_fraction: float,
        policy: Policy,
    ) -> torch.Tensor:
        """Return a boolean mask of active chunks for this step.

        During warmup all chunks are active. Afterwards, either a random
        subset or the top-k by predicted mass (with a tiny random tiebreaker)
        is selected.
        """
        if step < warmup_steps:
            return torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)

        k = max(1, int(active_fraction * self.n_chunks))
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)

        if policy == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        else:
            # Tiny noise breaks ties among equal (e.g. zero) predicted masses.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices

        mask[idx] = True
        return mask

    @torch.no_grad()
    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Flatten each chunk's weight grad (plus bias grad) into one vector.

        Missing gradients are replaced with zeros so every chunk yields a
        vector of consistent size.
        """
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = (local_c + 1) * self.chunk_size

            parts = []
            if m.weight.grad is None:
                parts.append(torch.zeros_like(m.weight[start:end]).flatten())
            else:
                parts.append(m.weight.grad[start:end].detach().flatten())

            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())

            vecs.append(torch.cat(parts))
        return vecs

    @torch.no_grad()
    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """Return the L2 norm of every chunk vector as one tensor on device."""
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    @torch.no_grad()
    def update_predictor(
        self,
        active_mask: torch.Tensor,
        vecs: List[torch.Tensor],
        mass_beta: float = 0.95,
        dir_beta: float = 0.95,
        store_history: bool = True,
    ) -> torch.Tensor:
        """Update mass/direction EMAs from observed (active) chunks only.

        Returns the full per-chunk mass tensor computed from *vecs*.
        """
        masses = self.chunk_masses_from_vecs(vecs)

        observed = active_mask
        # First observation should initialize directly, not get shrunk by beta.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen
        self.predicted_mass[never_seen] = masses[never_seen]
        self.predicted_mass[already_seen] = (
            mass_beta * self.predicted_mass[already_seen]
            + (1.0 - mass_beta) * masses[already_seen]
        )

        for i, is_active in enumerate(observed.tolist()):
            if not is_active:
                continue
            v = vecs[i]
            n = v.norm()
            if n <= 1e-12:
                # Skip degenerate gradients; direction would be meaningless.
                continue
            unit = v / n
            if self.direction_ema[i] is None:
                self.direction_ema[i] = unit.detach().clone()
            else:
                self.direction_ema[i] = (
                    dir_beta * self.direction_ema[i] + (1.0 - dir_beta) * unit
                )
            # Renormalize so the EMA stays a unit direction.
            self.direction_ema[i] = self.direction_ema[i] / (self.direction_ema[i].norm() + 1e-12)

        if store_history:
            self.mass_history.append(masses.detach().clone())
            # Keep a bounded window so similarity estimation stays cheap.
            max_hist = 128
            if len(self.mass_history) > max_hist:
                self.mass_history = self.mass_history[-max_hist:]

        return masses

    def layer_aware_masks(self) -> List[torch.Tensor]:
        """Return one boolean mask per module marking that module's chunks."""
        masks = []
        for m, ids in self.module_to_chunk_ids.items():
            mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
            mask[ids] = True
            masks.append(mask)
        return masks
344
+
345
+
346
def dense_cosine_from_vecs(a: List[torch.Tensor], b: List[torch.Tensor]) -> float:
    """Cosine similarity between two chunk-vector lists, flattened end to end."""
    flat_a = torch.cat([chunk.flatten() for chunk in a])
    flat_b = torch.cat([chunk.flatten() for chunk in b])
    cos = F.cosine_similarity(flat_a, flat_b, dim=0)
    return float(cos.item())
350
+
351
+
352
def mse_reduction_vs_zero(true_vecs: List[torch.Tensor], pred_vecs: List[torch.Tensor], mask: torch.Tensor) -> float:
    """Fractional MSE reduction of the prediction vs. the all-zero baseline.

    Considers only chunks selected by *mask*. Returns NaN when nothing is
    selected; 1.0 means a perfect prediction and 0.0 means no better than
    predicting zeros.
    """
    selected = torch.nonzero(mask, as_tuple=False).flatten().tolist()
    if not selected:
        return float("nan")
    truth = torch.cat([true_vecs[i].flatten() for i in selected])
    guess = torch.cat([pred_vecs[i].flatten() for i in selected])
    baseline_mse = torch.mean(truth.square())
    residual_mse = torch.mean((truth - guess).square())
    return float((1.0 - residual_mse / (baseline_mse + 1e-12)).item())
361
+
362
+
363
def active_only_prediction(true_vecs: List[torch.Tensor], active_mask: torch.Tensor) -> List[torch.Tensor]:
    """Keep true gradients for active chunks; predict zeros for inactive ones."""
    return [
        v.clone() if bool(active_mask[i]) else torch.zeros_like(v)
        for i, v in enumerate(true_vecs)
    ]
368
+
369
+
370
def ema_direction_prediction(
    cmap: ChunkMap,
    true_vecs: List[torch.Tensor],
    active_mask: torch.Tensor,
    inactive_magnitudes: torch.Tensor,
) -> List[torch.Tensor]:
    """Compose chunk predictions from observed grads and EMA directions.

    Active chunks pass through their true gradient vector. Inactive chunks
    are reconstructed as (EMA unit direction) x (predicted magnitude), or
    zeros when no direction history exists yet.
    """
    predictions: List[torch.Tensor] = []
    for i, v in enumerate(true_vecs):
        if bool(active_mask[i]):
            predictions.append(v.clone())
            continue
        direction = cmap.direction_ema[i]
        if direction is None:
            # No history for this chunk: fall back to the zero prediction.
            predictions.append(torch.zeros_like(v))
        else:
            predictions.append(direction.to(v.device, v.dtype) * inactive_magnitudes[i])
    return predictions
387
+
388
+
389
def build_mass_similarity(cmap: ChunkMap, min_history: int = 8) -> Optional[torch.Tensor]:
    """Chunk-by-chunk similarity from the correlation of mass histories.

    Returns None until *min_history* observations exist. Output is an
    [n_chunks, n_chunks] matrix of non-negative correlations with zero
    diagonal, restricted to pairs of chunks within the same module.
    """
    if len(cmap.mass_history) < min_history:
        return None

    H = torch.stack(cmap.mass_history, dim=0)  # [history, chunks]
    # Standardize each chunk's history before forming the correlation matrix.
    H = H - H.mean(dim=0, keepdim=True)
    H = H / (H.std(dim=0, keepdim=True) + 1e-6)

    S = (H.T @ H) / max(1, H.shape[0] - 1)
    # Negative correlations are discarded: only positive coupling is used.
    S = torch.clamp(S, min=0.0)

    # Remove self similarity.
    S.fill_diagonal_(0.0)

    # Layer-aware block diagonal: avoid mixing unrelated layers by default.
    layer_masks = cmap.layer_aware_masks()
    layer_allowed = torch.zeros_like(S, dtype=torch.bool)
    for mask in layer_masks:
        layer_allowed |= mask[:, None] & mask[None, :]
    S = torch.where(layer_allowed, S, torch.zeros_like(S))

    return S
411
+
412
+
413
def knn_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    k_neighbors: int = 3,
) -> torch.Tensor:
    """Predict inactive magnitudes as weighted average of correlated active magnitudes."""
    S = build_mass_similarity(cmap)
    if S is None:
        # Not enough history yet: fall back to the per-chunk EMA, with
        # observed active chunks overridden by their true masses.
        pred = cmap.predicted_mass.clone()
        pred[active_mask] = true_masses[active_mask]
        return pred

    pred = torch.zeros_like(true_masses)
    pred[active_mask] = true_masses[active_mask]

    active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
    inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()

    if active_idx.numel() == 0:
        return pred

    for i in inactive_idx.tolist():
        weights = S[i, active_idx]
        if weights.sum() <= 1e-12:
            # No correlated active sensor: keep the EMA estimate.
            pred[i] = cmap.predicted_mass[i]
            continue

        # Weighted mean of the k most-correlated active chunks' true masses.
        kk = min(k_neighbors, weights.numel())
        top = torch.topk(weights, k=kk)
        w = top.values
        aidx = active_idx[top.indices]
        pred[i] = (w * true_masses[aidx]).sum() / (w.sum() + 1e-12)

    return pred
448
+
449
+
450
def graph_diffusion_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    diffusion_steps: int = 8,
    alpha: float = 0.7,
) -> torch.Tensor:
    """Boundary-value style magnitude interpolation over a learned similarity graph.

    Active nodes are clamped to observed true magnitudes. Inactive nodes diffuse
    toward graph-neighbor values. *alpha* mixes the neighbor proposal against
    the current estimate each iteration.
    """
    S = build_mass_similarity(cmap)
    if S is None:
        # No graph yet: EMA estimate with observed chunks overridden.
        pred = cmap.predicted_mass.clone()
        pred[active_mask] = true_masses[active_mask]
        return pred

    # Row-normalize similarities into diffusion weights.
    W = S / (S.sum(dim=1, keepdim=True) + 1e-12)

    pred = cmap.predicted_mass.clone()
    pred[active_mask] = true_masses[active_mask]

    for _ in range(diffusion_steps):
        proposal = W @ pred
        pred = alpha * proposal + (1.0 - alpha) * pred
        # Re-clamp boundary (active) nodes after every diffusion step.
        pred[active_mask] = true_masses[active_mask]

    # Magnitudes are norms, so negative values are not meaningful.
    return torch.clamp(pred, min=0.0)
479
+
480
+
481
+ # -----------------------------
482
+ # Optimizer
483
+ # -----------------------------
484
+
485
class SimpleAdam:
    """Minimal Adam-like optimizer for diagnostics (no bias correction).

    Intentionally simple and identical across runs; it is not trying to be
    production AdamW.
    """

    def __init__(self, model: nn.Module, lr: float = 3e-4):
        self.model = model
        self.lr = lr
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

    def zero_grad(self):
        """Drop every parameter's gradient."""
        for param in self.model.parameters():
            param.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one update to each parameter that currently has a gradient."""
        for param in self.model.parameters():
            grad = param.grad
            if grad is None:
                continue
            if param not in self.state:
                # Lazily allocate first/second-moment buffers per parameter.
                self.state[param] = {
                    "m": torch.zeros_like(param),
                    "v": torch.zeros_like(param),
                }
            slots = self.state[param]
            slots["m"].mul_(0.9).add_(grad, alpha=0.1)
            slots["v"].mul_(0.999).addcmul_(grad, grad, value=0.001)
            param.sub_(slots["m"] / (torch.sqrt(slots["v"]) + 1e-8), alpha=self.lr)
516
+
517
+
518
+ # -----------------------------
519
+ # Apply chunk-gradient predictions
520
+ # -----------------------------
521
+
522
@torch.no_grad()
def install_chunk_prediction_as_grads(
    cmap: ChunkMap,
    pred_vecs: List[torch.Tensor],
):
    """Overwrite SparseLinear weight/bias grads from predicted chunk vectors.

    Non-SparseLinear parameters keep their dense gradients. Each predicted
    vector is laid out as [chunk_size * fan_in weight values, then chunk_size
    bias values] — the same order produced by chunk_gradient_vectors.
    """
    for m, ids in cmap.module_to_chunk_ids.items():
        if m.weight.grad is None:
            continue
        # Zero existing grads first; chunks are then filled in one by one.
        m.weight.grad.zero_()
        if m.bias is not None and m.bias.grad is not None:
            m.bias.grad.zero_()

        for local_c, global_id in enumerate(ids.tolist()):
            start = local_c * cmap.chunk_size
            end = (local_c + 1) * cmap.chunk_size

            v = pred_vecs[global_id]
            # Leading portion of the flat vector is the weight-grad block.
            w_numel = cmap.chunk_size * m.weight.shape[1]
            w_flat = v[:w_numel]
            m.weight.grad[start:end] = w_flat.view(cmap.chunk_size, m.weight.shape[1])

            if m.bias is not None and m.bias.grad is not None:
                # Remainder (if any) is the bias-grad block.
                b_flat = v[w_numel:]
                if b_flat.numel() > 0:
                    m.bias.grad[start:end] = b_flat.view(cmap.chunk_size)
551
+
552
+
553
+ # -----------------------------
554
+ # Training / diagnostics
555
+ # -----------------------------
556
+
557
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Compute one deterministic validation-batch loss.

    Switches the model to eval mode for the forward pass and back to train
    mode afterwards; the fixed seed makes the batch reproducible.
    """
    model.eval()
    with torch.no_grad():
        batch_x, batch_y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
        _, loss = model(batch_x, batch_y)
    model.train()
    return float(loss.item())
564
+
565
+
566
def run_experiment(
    mode: str,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    policy: Policy,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train one MiniGPT under *mode* and return diagnostic summary metrics.

    Dense gradients are computed every step (this is a diagnostic, not a
    speed benchmark); per-mode chunk predictions are then installed as the
    gradients actually used by the optimizer. Returns a dict with final
    validation loss ("val"), ms per step ("ms"), average cosine between true
    and installed gradients ("cos"), and average inactive-chunk MSE
    reduction vs. zero ("mse_red").
    """
    set_seed(42)
    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    cmap = ChunkMap(model, chunk_size=chunk_size, device=device)

    metric_rows = []

    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()

    for step in range(steps):
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))

        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        # Dense ground truth for this step's chunk gradients.
        true_vecs = cmap.chunk_gradient_vectors()
        true_masses = cmap.chunk_masses_from_vecs(true_vecs)

        active_mask = cmap.choose_active(
            step=step,
            warmup_steps=warmup_steps,
            active_fraction=active_fraction,
            policy=policy,
        )

        if step < warmup_steps or mode == "dense":
            # Dense mode / warmup: install the true gradients unchanged.
            pred_vecs = [v.clone() for v in true_vecs]
        else:
            active_only_vecs = active_only_prediction(true_vecs, active_mask)

            if mode == "active_only":
                pred_vecs = active_only_vecs

            elif mode == "knn_magnitude":
                pred_masses = knn_magnitude_prediction(cmap, active_mask, true_masses)
                pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)

            elif mode == "graph_diffusion":
                pred_masses = graph_diffusion_magnitude_prediction(cmap, active_mask, true_masses)
                pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)

            elif mode == "ema_inactive":
                pred_masses = cmap.predicted_mass.clone()
                pred_masses[active_mask] = true_masses[active_mask]
                pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)

            else:
                raise ValueError(f"Unknown mode: {mode}")

        install_chunk_prediction_as_grads(cmap, pred_vecs)

        if step % 25 == 0:
            # Periodic diagnostics: prediction quality and current val loss.
            inactive_mask = ~active_mask
            row = {
                "cosine_full": dense_cosine_from_vecs(true_vecs, pred_vecs),
                "inactive_mse_reduction": mse_reduction_vs_zero(true_vecs, pred_vecs, inactive_mask),
                "active_frac": float(active_mask.float().mean().item()),
                "val": evaluate(model, corpus, batch_size, seed=999 + step),
            }
            metric_rows.append(row)

        # Update predictor after measuring and installing predicted grads.
        # Use true active chunk observations only, mimicking sparse observation.
        cmap.update_predictor(active_mask, true_vecs, store_history=True)

        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0

    val_loss = evaluate(model, corpus, batch_size, seed=12345)

    if metric_rows:
        avg_cos = sum(r["cosine_full"] for r in metric_rows) / len(metric_rows)
        avg_mse_red = sum(r["inactive_mse_reduction"] for r in metric_rows) / len(metric_rows)
    else:
        avg_cos = float("nan")
        avg_mse_red = float("nan")

    return {
        "val": val_loss,
        "ms": 1000.0 * elapsed / steps,
        "cos": avg_cos,
        "mse_red": avg_mse_red,
    }
671
+
672
+
673
def main():
    """Run the inactive-update prediction diagnostic across all modes.

    Parses CLI options, runs ``run_experiment`` once per prediction mode
    (dense baseline plus four sparse-prediction variants) and prints a
    comparison table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=300)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--policy", type=str, default="predicted_magnitude", choices=["predicted_magnitude", "random"])
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    modes = [
        "dense",
        "active_only",
        "ema_inactive",
        "knn_magnitude",
        "graph_diffusion",
    ]

    print(f"\nInactive-update prediction diagnostic")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunks={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps} policy={args.policy}\n")
    print(f"{'mode':>18s} | {'val':>8s} | {'ms/step':>8s} | {'grad_cos':>8s} | {'inactive_mse+':>13s}")
    print("-" * 70)

    for mode in modes:
        result = run_experiment(
            mode=mode,
            device=args.device,
            steps=args.steps,
            batch_size=args.batch_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            chunk_size=args.chunk_size,
            active_fraction=args.active_fraction,
            warmup_steps=args.warmup_steps,
            policy=args.policy,
            benchmark_sync=args.benchmark_sync,
        )
        print(
            f"{mode:>18s} | "
            f"{result['val']:8.4f} | "
            f"{result['ms']:8.2f} | "
            f"{result['cos']:8.3f} | "
            f"{result['mse_red']:13.3f}"
        )
726
+
727
+
728
# Script entry point: run the diagnostic when executed directly.
if __name__ == "__main__":
    main()
experiments/sparse_transformer_v16_sensor_scheduler.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v16: Sensor-Based Mask Scheduling.
3
+
4
+ v15 showed that directly hallucinating inactive gradient vectors was harmful.
5
+ v16 tests the safer next idea:
6
+
7
+ Use active chunks as sensors to choose which chunks receive real gradients next.
8
+
9
+ No inactive gradient is invented. In sparse modes, inactive chunks get zero gradient.
10
+ The only question is whether active chunk observations improve future mask selection.
11
+
12
+ Schedulers:
13
+ dense
14
+ Dense baseline.
15
+
16
+ ema_topk
17
+ Select top chunks by each chunk's own EMA gradient mass.
18
+
19
+ knn_scheduler
20
+ Use active chunks as sensors. Predict next-step inactive chunk mass from
21
+ historically correlated active chunks. Select next mask from that score.
22
+
23
+ graph_scheduler
24
+ Boundary-value style magnitude diffusion over a chunk similarity graph.
25
+ Active chunks are clamped to observed magnitudes. Inactive magnitudes are
26
+ interpolated and used to choose the next mask.
27
+
28
+ random
29
+ Random sparse-support control.
30
+
31
+ This is still a diagnostic/simulation script: it computes dense gradients so we can
32
+ measure oracle Jaccard/cosine, then installs only the selected active chunk gradients
33
+ for sparse training.
34
+
35
+ Run:
36
+ python3 sparse_transformer_v16_sensor_scheduler.py --device mps --benchmark_sync
37
+
38
+ Useful:
39
+ python3 sparse_transformer_v16_sensor_scheduler.py --device mps --steps 500 --n_embd 512
40
+ python3 sparse_transformer_v16_sensor_scheduler.py --device mps --steps 500 --n_embd 1024
41
+ """
42
+
43
+ from __future__ import annotations
44
+
45
+ import argparse
46
+ import math
47
+ import random
48
+ import time
49
+ from typing import Dict, List, Literal, Optional, Tuple
50
+
51
+ import torch
52
+
53
+ torch.set_num_threads(1)
54
+ import torch.nn as nn
55
+ import torch.nn.functional as F
56
+
57
+ Scheduler = Literal["dense", "ema_topk", "knn_scheduler", "graph_scheduler", "random"]
58
+
59
+
60
def sync_device(device: str) -> None:
    """Block until all queued kernels on `device` finish (no-op for CPU/unknown)."""
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
65
+
66
+
67
def set_seed(seed: int) -> None:
    """Seed both Python's and torch's global RNGs for reproducibility."""
    random.seed(seed)
    torch.manual_seed(seed)
70
+
71
+
72
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a deterministically-seeded CPU torch.Generator."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator
76
+
77
+
78
+ # -----------------------------
79
+ # Data
80
+ # -----------------------------
81
+
82
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic toy corpus: one random word-salad sentence per line."""
    rng = random.Random(seed)
    words = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sentences = []
    for _ in range(n_sentences):
        # randint is evaluated before choices, matching the generator call order.
        sentences.append(" ".join(rng.choices(words, k=rng.randint(4, 10))) + ".")
    return "\n".join(sentences)
93
+
94
+
95
class CharCorpus:
    """Character-level corpus with a 90/10 train/val split and random block sampling."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device
        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample `batch_size` random (input, next-char target) blocks from a split."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        inputs = torch.stack([s.item() * 0 + source[s : s + self.block_size] for s in starts]) if False else torch.stack([source[s : s + self.block_size] for s in starts])
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
118
+
119
+
120
+ # -----------------------------
121
+ # Model
122
+ # -----------------------------
123
+
124
class SparseLinear(nn.Linear):
    """Marker subclass of nn.Linear.

    Behaviorally identical to nn.Linear; exists so the scheduler can locate the
    layers that participate in chunked sparse gradient updates via isinstance
    (see get_sparse_linears).
    """
    pass
126
+
127
+
128
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an explicit lower-triangular mask."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        tril = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", tril.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        # Fused q/k/v projection, then split into heads: [B, n_head, T, head_dim].
        q, k, v = self.c_attn(x).split(C, dim=2)

        def heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = heads(q), heads(k), heads(v)

        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        # Forbid attention to future positions.
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))

        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
159
+
160
+
161
class FeedForward(nn.Module):
    """Position-wise MLP: expand 4x, GELU, project back, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
170
+
171
+
172
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
184
+
185
+
186
class MiniGPT(nn.Module):
    """Minimal GPT: token + position embeddings, N transformer blocks, LM head."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are given."""
        B, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)

        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
218
+
219
+
220
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every SparseLinear in the model, in module-traversal order."""
    return [module for module in model.modules() if isinstance(module, SparseLinear)]
222
+
223
+
224
+ # -----------------------------
225
+ # Chunk map and scheduler
226
+ # -----------------------------
227
+
228
class ChunkScheduler:
    """Decides which output-row chunks of each SparseLinear get real gradients.

    A "chunk" is a contiguous group of `chunk_size` output rows; chunks are
    numbered globally across every SparseLinear in the model.  The scheduler
    tracks per-chunk gradient mass (L2 norm) and, depending on `scheduler`,
    selects the next active mask from EMA mass, sensor-based KNN prediction,
    graph diffusion, or uniformly at random.  It never invents gradients for
    inactive chunks; it only decides the mask.
    """

    def __init__(
        self,
        model: nn.Module,
        chunk_size: int,
        active_fraction: float,
        device: str,
        scheduler: Scheduler,
        mass_beta: float = 0.95,
    ):
        self.model = model
        self.chunk_size = chunk_size
        self.active_fraction = active_fraction
        self.device = device
        self.scheduler = scheduler
        # EMA decay for the per-chunk gradient-mass estimate.
        self.mass_beta = mass_beta

        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []

        # Assign a global chunk id to every (module, local chunk index) pair.
        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            for local_c in range(n_chunks):
                self.chunk_to_module_local.append((m, local_c))
            offset += n_chunks

        self.n_chunks = offset
        # EMA of observed chunk gradient mass; 0 means "never observed yet".
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # Dense mass snapshots recorded during warmup to fit the similarity graph.
        self.mass_history: List[torch.Tensor] = []

        self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=device)
        # Sensor-derived score used when picking the *next* mask.
        self.next_scores = torch.zeros(self.n_chunks, device=device)

        self.prev_mask: Optional[torch.Tensor] = None
        self.similarity: Optional[torch.Tensor] = None

    def k_active(self) -> int:
        """Number of chunks kept active per step (at least one)."""
        return max(1, int(self.active_fraction * self.n_chunks))

    def choose_mask(self, step: int, warmup_steps: int) -> torch.Tensor:
        """Select and store the active chunk mask for this step.

        Dense mode and warmup steps activate everything; otherwise pick top-k by
        the scheduler-specific score (tiny random jitter breaks ties).
        """
        if self.scheduler == "dense" or step < warmup_steps:
            self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)
            return self.current_mask

        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)

        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]

        elif self.scheduler == "ema_topk":
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices

        elif self.scheduler in ("knn_scheduler", "graph_scheduler"):
            # next_scores are computed from the previous step's active sensors.
            # If unavailable (all zero), fall back to the EMA estimate.
            base = self.next_scores
            if torch.count_nonzero(base).item() == 0:
                base = self.predicted_mass
            scores = base + 1e-9 * torch.rand_like(base)
            idx = torch.topk(scores, k=k).indices

        else:
            raise ValueError(f"Unknown scheduler: {self.scheduler}")

        mask[idx] = True
        self.current_mask = mask
        return mask

    @torch.no_grad()
    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Flattened per-chunk gradient (weight rows + matching bias slice).

        Missing gradients are represented by zeros of the same shape so every
        chunk vector has a consistent length.
        """
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = (local_c + 1) * self.chunk_size

            parts = []
            if m.weight.grad is None:
                parts.append(torch.zeros_like(m.weight[start:end]).flatten())
            else:
                parts.append(m.weight.grad[start:end].detach().flatten())

            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())

            vecs.append(torch.cat(parts))
        return vecs

    @torch.no_grad()
    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """L2 norm of each chunk gradient vector, on `self.device`."""
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    @torch.no_grad()
    def update_from_observed(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        step: int,
        warmup_steps: int,
    ) -> None:
        """Fold observed chunk masses into the EMA and refresh sensor scores.

        Only chunks flagged in `active_mask` count as observed.  A chunk's first
        observation is installed directly; later ones blend via `mass_beta`.
        """
        observed = active_mask

        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen

        self.predicted_mass[never_seen] = true_masses[never_seen]
        self.predicted_mass[already_seen] = (
            self.mass_beta * self.predicted_mass[already_seen]
            + (1.0 - self.mass_beta) * true_masses[already_seen]
        )

        # During warmup we store dense mass histories to learn the similarity graph.
        if step < warmup_steps:
            self.mass_history.append(true_masses.detach().clone())
            max_hist = 128
            if len(self.mass_history) > max_hist:
                self.mass_history = self.mass_history[-max_hist:]

            if len(self.mass_history) >= 8:
                self.similarity = self.build_similarity()

        # Compute next_scores from the current active observations.
        if self.scheduler == "knn_scheduler":
            self.next_scores = self.knn_scores(active_mask, true_masses)
        elif self.scheduler == "graph_scheduler":
            self.next_scores = self.diffusion_scores(active_mask, true_masses)
        else:
            self.next_scores = self.predicted_mass.clone()

    def layer_allowed_mask(self) -> torch.Tensor:
        """Boolean [n_chunks, n_chunks] mask, True only for same-layer chunk pairs.

        NOTE: the previous version first built a shape-mismatched placeholder via
        `allowed |= ids[:, None].eq(ids[None, :])` (which broadcasts against the
        full matrix and raises whenever a layer owns fewer chunks than the total)
        before zeroing and rebuilding it.  The dead placeholder is removed.
        """
        allowed = torch.zeros((self.n_chunks, self.n_chunks), dtype=torch.bool, device=self.device)
        for ids in self.module_to_chunk_ids.values():
            allowed[ids[:, None], ids[None, :]] = True
        return allowed

    def build_similarity(self) -> torch.Tensor:
        """Nonnegative within-layer correlation of chunk masses over warmup history."""
        H = torch.stack(self.mass_history, dim=0)  # [history, chunks]
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)

        S = (H.T @ H) / max(1, H.shape[0] - 1)
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)

        # Keep only within-layer similarities. Cross-layer correlation is too easy
        # to overfit in this tiny diagnostic.
        S = torch.where(self.layer_allowed_mask(), S, torch.zeros_like(S))
        return S

    def knn_scores(self, active_mask: torch.Tensor, true_masses: torch.Tensor, k_neighbors: int = 3) -> torch.Tensor:
        """Predict inactive chunk mass from the k most correlated active sensors."""
        if self.similarity is None:
            return self.predicted_mass.clone()

        S = self.similarity
        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]

        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()

        if active_idx.numel() == 0:
            return scores

        for i in inactive_idx.tolist():
            weights = S[i, active_idx]
            if weights.sum() <= 1e-12:
                continue
            kk = min(k_neighbors, weights.numel())
            top = torch.topk(weights, k=kk)
            w = top.values
            aidx = active_idx[top.indices]
            # Similarity-weighted mean of the observed sensor masses.
            scores[i] = (w * true_masses[aidx]).sum() / (w.sum() + 1e-12)

        return scores

    def diffusion_scores(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        diffusion_steps: int = 8,
        alpha: float = 0.7,
    ) -> torch.Tensor:
        """Boundary-value diffusion: clamp active chunks, interpolate the rest."""
        if self.similarity is None:
            return self.predicted_mass.clone()

        S = self.similarity
        # Row-normalized transition matrix over the similarity graph.
        W = S / (S.sum(dim=1, keepdim=True) + 1e-12)

        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]

        for _ in range(diffusion_steps):
            proposal = W @ scores
            scores = alpha * proposal + (1.0 - alpha) * scores
            # Re-clamp observed values each iteration (Dirichlet-style boundary).
            scores[active_mask] = true_masses[active_mask]

        return torch.clamp(scores, min=0.0)

    def oracle_topk_mask(self, true_masses: torch.Tensor) -> torch.Tensor:
        """Mask of the truly largest-mass chunks (uses dense info; diagnostics only)."""
        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        mask[torch.topk(true_masses, k=k).indices] = True
        return mask
449
+
450
+
451
+ # -----------------------------
452
+ # Gradient installation and metrics
453
+ # -----------------------------
454
+
455
@torch.no_grad()
def install_active_only_grads(sched: ChunkScheduler, active_mask: torch.Tensor) -> None:
    """Zero the gradients of every inactive chunk, leaving active chunks intact.

    Dense mode is a no-op.  Instead of a Python loop per chunk (duplicated for
    weight and bias in the original), each chunk-level inactive flag is expanded
    to a per-output-row mask with `repeat_interleave`, so all inactive rows are
    zeroed in one indexed assignment per tensor.  Behavior is unchanged.
    """
    if sched.scheduler == "dense":
        return

    for m, ids in sched.module_to_chunk_ids.items():
        inactive = ~active_mask[ids]
        if not bool(inactive.any()):
            continue
        # One flag per chunk -> one flag per output row of this module.
        row_inactive = torch.repeat_interleave(inactive, sched.chunk_size)
        if m.weight.grad is not None:
            m.weight.grad[row_inactive] = 0.0
        if m.bias is not None and m.bias.grad is not None:
            m.bias.grad[row_inactive] = 0.0
475
+
476
+
477
def dense_cosine_active_only(vecs: List[torch.Tensor], active_mask: torch.Tensor) -> float:
    """Cosine similarity between the full gradient and its active-only projection."""
    full = torch.cat([v.flatten() for v in vecs])
    masked = torch.cat(
        [
            v.flatten() if bool(active_mask[i]) else torch.zeros_like(v).flatten()
            for i, v in enumerate(vecs)
        ]
    )
    return float(F.cosine_similarity(full, masked, dim=0).item())
484
+
485
+
486
def jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
    """Jaccard overlap of two boolean masks (0.0 when both are empty)."""
    intersection = (a & b).sum().float()
    union = torch.clamp((a | b).sum().float(), min=1.0)
    return float((intersection / union).item())
490
+
491
+
492
class SimpleAdam:
    """Minimal Adam (beta1=0.9, beta2=0.999, eps=1e-8, no bias correction)."""

    def __init__(self, model: nn.Module, lr: float = 3e-4):
        self.model = model
        self.lr = lr
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

    def zero_grad(self):
        # Drop grads entirely instead of zero-filling them.
        for param in self.model.parameters():
            param.grad = None

    @torch.no_grad()
    def step(self):
        for param in self.model.parameters():
            grad = param.grad
            if grad is None:
                continue
            if param not in self.state:
                self.state[param] = {"m": torch.zeros_like(param), "v": torch.zeros_like(param)}
            slot = self.state[param]
            m, v = slot["m"], slot["v"]
            m.mul_(0.9).add_(grad, alpha=0.1)
            v.mul_(0.999).addcmul_(grad, grad, value=0.001)
            param.sub_(m / (v.sqrt() + 1e-8), alpha=self.lr)
514
+
515
+
516
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """One-batch validation loss with a fixed seed; restores train mode."""
    model.eval()
    with torch.no_grad():
        xb, yb = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
        _, loss = model(xb, yb)
    model.train()
    return float(loss.item())
523
+
524
+
525
def run_experiment(
    scheduler_name: Scheduler,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train MiniGPT for `steps` under one scheduler and report diagnostics.

    Each step does a DENSE backward pass (so oracle metrics can be measured),
    then zeroes inactive chunk gradients before the optimizer step.

    Returns a dict with:
        val    -- final one-batch validation loss
        ms     -- mean wall-clock milliseconds per step
        cos    -- mean active-only gradient cosine vs. the dense gradient
        jacc   -- mean Jaccard overlap between chosen mask and oracle top-k mask
        stable -- mean Jaccard overlap between consecutive masks
    (cos/jacc/stable are NaN for the dense scheduler, which records no rows.)
    """
    set_seed(42)

    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    sched = ChunkScheduler(
        model=model,
        chunk_size=chunk_size,
        active_fraction=active_fraction,
        device=device,
        scheduler=scheduler_name,
    )

    metric_rows = []

    # Synchronize before timing so queued kernels don't leak into the window.
    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()

    for step in range(steps):
        # Per-step CPU generator makes every scheduler see identical batches.
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))

        active_mask = sched.choose_mask(step=step, warmup_steps=warmup_steps)

        # Dense backward: full gradients exist so oracle metrics can be computed.
        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        vecs = sched.chunk_gradient_vectors()
        masses = sched.chunk_masses_from_vecs(vecs)

        # Diagnostics only after warmup and only for sparse schedulers.
        if step >= warmup_steps and scheduler_name != "dense":
            oracle = sched.oracle_topk_mask(masses)
            row = {
                "cos": dense_cosine_active_only(vecs, active_mask),
                "jacc": jaccard(active_mask, oracle),
                "stable": jaccard(active_mask, sched.prev_mask) if sched.prev_mask is not None else 0.0,
                "val": evaluate(model, corpus, batch_size, seed=10_000 + step) if step % 50 == 0 else float("nan"),
            }
            metric_rows.append(row)

        # Zero the inactive chunk gradients so the update is truly sparse.
        install_active_only_grads(sched, active_mask)

        # Important: update scheduler from the active observations only.
        # Dense gradients exist for diagnostics, but unselected chunks should not
        # teach the sparse scheduler after warmup.
        observed_for_scheduler = active_mask if step >= warmup_steps else torch.ones_like(active_mask)
        sched.update_from_observed(
            active_mask=observed_for_scheduler,
            true_masses=masses,
            step=step,
            warmup_steps=warmup_steps,
        )

        sched.prev_mask = active_mask.clone()

        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0

    val_loss = evaluate(model, corpus, batch_size, seed=12345)

    # Average the per-step diagnostics; dense mode has none.
    if metric_rows:
        avg_cos = sum(r["cos"] for r in metric_rows) / len(metric_rows)
        avg_jacc = sum(r["jacc"] for r in metric_rows) / len(metric_rows)
        avg_stable = sum(r["stable"] for r in metric_rows) / len(metric_rows)
    else:
        avg_cos = float("nan")
        avg_jacc = float("nan")
        avg_stable = float("nan")

    return {
        "val": val_loss,
        "ms": 1000.0 * elapsed / steps,
        "cos": avg_cos,
        "jacc": avg_jacc,
        "stable": avg_stable,
    }
619
+
620
+
621
def main() -> None:
    """CLI entry point: run every scheduler variant and print a comparison table."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    schedulers: List[Scheduler] = [
        "dense",
        "ema_topk",
        "knn_scheduler",
        "graph_scheduler",
        "random",
    ]

    print("\nSensor-based mask scheduling diagnostic")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunks={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps}\n")
    print(f"{'scheduler':>18s} | {'val':>8s} | {'ms/step':>8s} | {'grad_cos':>8s} | {'jacc':>8s} | {'stable':>8s}")
    print("-" * 78)

    # Everything except the scheduler name is shared across runs.
    shared = dict(
        device=args.device,
        steps=args.steps,
        batch_size=args.batch_size,
        block_size=args.block_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        chunk_size=args.chunk_size,
        active_fraction=args.active_fraction,
        warmup_steps=args.warmup_steps,
        benchmark_sync=args.benchmark_sync,
    )

    for name in schedulers:
        result = run_experiment(scheduler_name=name, **shared)
        print(
            f"{name:>18s} | "
            f"{result['val']:8.4f} | "
            f"{result['ms']:8.2f} | "
            f"{result['cos']:8.3f} | "
            f"{result['jacc']:8.3f} | "
            f"{result['stable']:8.3f}"
        )
675
+
676
# Standard script entry-point guard: run the scheduler comparison sweep only
# when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
experiments/sparse_transformer_v17_radar_scheduler.py ADDED
@@ -0,0 +1,725 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v17: Radar Scheduler Diagnostic.
3
+
4
+ v16 showed:
5
+ - Directly predicting inactive gradient vectors is harmful.
6
+ - Using active chunks as sensors to schedule the next active mask works.
7
+ - KNN/graph sensors improve oracle overlap and gradient cosine, but churn masks.
8
+ - EMA is stable, but can be blind.
9
+
10
+ v17 tests the fusion:
11
+
12
+ radar_score = alpha * normalized_ema_mass
13
+ + (1 - alpha) * normalized_sensor_score
14
+
15
+ where sensor_score is either:
16
+ - KNN over a learned chunk-mass correlation graph
17
+ - graph diffusion / boundary interpolation over that graph
18
+
19
+ This is still a diagnostic script. It computes dense gradients so we can measure:
20
+ - oracle Jaccard
21
+ - active-only full-gradient cosine
22
+ - mask stability
23
+ - validation loss after sparse active-only updates
24
+
25
+ No inactive gradients are invented. In sparse modes, inactive chunks get zeroed.
26
+
27
+ Run:
28
+ python3 sparse_transformer_v17_radar_scheduler.py --device mps --benchmark_sync
29
+
30
+ Useful:
31
+ python3 sparse_transformer_v17_radar_scheduler.py --device mps --steps 500 --n_embd 512
32
+ python3 sparse_transformer_v17_radar_scheduler.py --device mps --steps 500 --n_embd 1024
33
+ python3 sparse_transformer_v17_radar_scheduler.py --device mps --steps 500 --n_embd 1024 --alphas 0.25 0.5 0.75 0.9
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import argparse
39
+ import math
40
+ import random
41
+ import time
42
+ from typing import Dict, List, Literal, Optional, Tuple
43
+
44
+ import torch
45
+
46
+ torch.set_num_threads(1)
47
+ import torch.nn as nn
48
+ import torch.nn.functional as F
49
+
50
+ Scheduler = Literal[
51
+ "dense",
52
+ "ema_topk",
53
+ "knn_scheduler",
54
+ "graph_scheduler",
55
+ "radar_knn",
56
+ "radar_graph",
57
+ "random",
58
+ ]
59
+
60
+
61
def sync_device(device: str) -> None:
    """Wait for all pending work on `device` (CPU and unknown devices: no-op)."""
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
66
+
67
+
68
def set_seed(seed: int) -> None:
    """Seed Python's and torch's global RNGs."""
    random.seed(seed)
    torch.manual_seed(seed)
71
+
72
+
73
def make_cpu_generator(seed: int) -> torch.Generator:
    """Deterministically seeded CPU torch.Generator."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator
77
+
78
+
79
def normalize_scores(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Robust [0, 1] normalization.

    We avoid z-score because heavy tails are the signal, not necessarily noise.
    NaN and infinities are mapped to 0 first; a constant tensor maps to zeros.
    """
    clean = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    lo, hi = clean.min(), clean.max()
    if (hi - lo) <= eps:
        return torch.zeros_like(clean)
    return (clean - lo) / (hi - lo + eps)
90
+
91
+
92
+ # -----------------------------
93
+ # Data
94
+ # -----------------------------
95
+
96
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Deterministic toy corpus: one random word-salad sentence per line."""
    rng = random.Random(seed)
    words = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sentences = []
    for _ in range(n_sentences):
        # randint is evaluated before choices, preserving the RNG call order.
        sentences.append(" ".join(rng.choices(words, k=rng.randint(4, 10))) + ".")
    return "\n".join(sentences)
107
+
108
+
109
class CharCorpus:
    """Character corpus: 90/10 train/val split, random block batches."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device
        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample `batch_size` random (input, next-char target) blocks."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        inputs = torch.stack([source[s : s + self.block_size] for s in starts])
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
132
+
133
+
134
+ # -----------------------------
135
+ # Model
136
+ # -----------------------------
137
+
138
class SparseLinear(nn.Linear):
    """Marker subclass of nn.Linear.

    Behaviorally identical to nn.Linear; exists so the scheduler can locate the
    layers participating in chunked sparse updates via isinstance
    (see get_sparse_linears).
    """
    pass
140
+
141
+
142
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an explicit triangular mask buffer."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        tril = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", tril.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        # Fused q/k/v projection, then split into heads: [B, n_head, T, head_dim].
        q, k, v = self.c_attn(x).split(C, dim=2)

        def heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = heads(q), heads(k), heads(v)

        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        # Forbid attention to future positions.
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))

        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
173
+
174
+
175
class FeedForward(nn.Module):
    """Position-wise MLP: expand 4x, GELU, project back, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
184
+
185
+
186
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
198
+
199
+
200
class MiniGPT(nn.Module):
    """Minimal GPT: token + position embeddings, N transformer blocks, LM head."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are given."""
        B, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)

        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
232
+
233
+
234
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every SparseLinear submodule of *model* in traversal order."""
    found: List[SparseLinear] = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
236
+
237
+
238
+ # -----------------------------
239
+ # Radar scheduler
240
+ # -----------------------------
241
+
242
class RadarScheduler:
    """Chunk-level gradient scheduler.

    Splits every SparseLinear's output rows into fixed-size chunks and decides,
    each step, which chunks get to keep their gradients ("active"). Selection
    can be random, EMA-of-mass top-k, similarity-graph based (kNN/diffusion),
    or a blend ("radar_*") controlled by ``alpha``. A dense warmup phase
    observes all chunks to seed the EMA and the similarity graph.
    """

    def __init__(
        self,
        model: nn.Module,
        chunk_size: int,
        active_fraction: float,
        device: str,
        scheduler: Scheduler,
        alpha: float,
        mass_beta: float = 0.95,
        similarity_history: int = 128,
        min_similarity_history: int = 8,
    ):
        self.model = model
        self.chunk_size = chunk_size
        self.active_fraction = active_fraction
        self.device = device
        self.scheduler = scheduler
        self.alpha = float(alpha)
        # EMA coefficient for predicted chunk gradient mass.
        self.mass_beta = mass_beta
        # Max / min number of dense-warmup mass snapshots kept for the graph.
        self.similarity_history = similarity_history
        self.min_similarity_history = min_similarity_history

        self.linears = get_sparse_linears(model)
        # Per-module global chunk ids, and the inverse (chunk id -> (module, local idx)).
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []

        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            for local_c in range(n_chunks):
                self.chunk_to_module_local.append((m, local_c))
            offset += n_chunks

        self.n_chunks = offset
        # EMA estimate of each chunk's gradient norm (0 == never observed).
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.mass_history: List[torch.Tensor] = []

        self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=device)
        # Sensor-derived (kNN/diffusion) scores proposed for the next step.
        self.next_sensor_scores = torch.zeros(self.n_chunks, device=device)
        self.next_scores = torch.zeros(self.n_chunks, device=device)
        self.prev_mask: Optional[torch.Tensor] = None
        # Chunk-by-chunk correlation matrix; None until enough history exists.
        self.similarity: Optional[torch.Tensor] = None

    def k_active(self) -> int:
        """Number of chunks kept active (at least 1)."""
        return max(1, int(self.active_fraction * self.n_chunks))

    def choose_mask(self, step: int, warmup_steps: int) -> torch.Tensor:
        """Pick this step's active-chunk boolean mask and cache it."""
        # Dense policy and the warmup phase activate everything.
        if self.scheduler == "dense" or step < warmup_steps:
            self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)
            return self.current_mask

        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)

        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        else:
            scores = self.score_for_selection()
            # Tiny random jitter breaks ties deterministically enough for topk.
            scores = scores + 1e-9 * torch.rand_like(scores)
            idx = torch.topk(scores, k=k).indices

        mask[idx] = True
        self.current_mask = mask
        return mask

    def score_for_selection(self) -> torch.Tensor:
        """Per-chunk selection scores according to the configured policy."""
        if self.scheduler == "random":
            return torch.zeros_like(self.predicted_mass)

        if self.scheduler == "ema_topk":
            return self.predicted_mass

        if self.scheduler in ("knn_scheduler", "graph_scheduler"):
            # Fall back to the EMA until the sensor has produced anything.
            if torch.count_nonzero(self.next_sensor_scores).item() == 0:
                return self.predicted_mass
            return self.next_sensor_scores

        if self.scheduler in ("radar_knn", "radar_graph"):
            sensor = self.next_sensor_scores
            if torch.count_nonzero(sensor).item() == 0:
                sensor = self.predicted_mass

            # Blend normalized EMA and sensor scores: alpha weights the EMA.
            ema_n = normalize_scores(self.predicted_mass)
            sensor_n = normalize_scores(sensor)
            return self.alpha * ema_n + (1.0 - self.alpha) * sensor_n

        if self.scheduler == "dense":
            return torch.ones_like(self.predicted_mass)

        raise ValueError(f"Unknown scheduler: {self.scheduler}")

    @torch.no_grad()
    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Flattened (weight rows + bias slice) gradient vector per chunk.

        Missing grads are replaced by zeros so every chunk yields a vector.
        """
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = (local_c + 1) * self.chunk_size

            parts = []
            if m.weight.grad is None:
                parts.append(torch.zeros_like(m.weight[start:end]).flatten())
            else:
                parts.append(m.weight.grad[start:end].detach().flatten())

            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())

            vecs.append(torch.cat(parts))
        return vecs

    @torch.no_grad()
    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """L2 norm of each chunk gradient vector, on the scheduler's device."""
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    @torch.no_grad()
    def update_from_observed(
        self,
        observed_mask: torch.Tensor,
        true_masses: torch.Tensor,
        step: int,
        warmup_steps: int,
    ) -> None:
        """Fold observed chunk masses into the EMA and refresh sensor scores.

        Only chunks flagged in ``observed_mask`` update the EMA; a chunk seen
        for the first time (EMA still zero) adopts its mass directly.
        """
        observed = observed_mask

        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen

        self.predicted_mass[never_seen] = true_masses[never_seen]
        self.predicted_mass[already_seen] = (
            self.mass_beta * self.predicted_mass[already_seen]
            + (1.0 - self.mass_beta) * true_masses[already_seen]
        )

        # Dense warmup teaches the similarity graph.
        if step < warmup_steps:
            self.mass_history.append(true_masses.detach().clone())
            if len(self.mass_history) > self.similarity_history:
                self.mass_history = self.mass_history[-self.similarity_history :]

            if len(self.mass_history) >= self.min_similarity_history:
                self.similarity = self.build_similarity()

        if self.scheduler in ("knn_scheduler", "radar_knn"):
            self.next_sensor_scores = self.knn_scores(observed, true_masses)
        elif self.scheduler in ("graph_scheduler", "radar_graph"):
            self.next_sensor_scores = self.diffusion_scores(observed, true_masses)
        else:
            self.next_sensor_scores = self.predicted_mass.clone()

        self.next_scores = self.score_for_selection()

    def build_similarity(self) -> torch.Tensor:
        """Non-negative chunk-chunk correlation matrix from warmup mass history."""
        H = torch.stack(self.mass_history, dim=0)  # [history, chunks]
        # Standardize each chunk's mass series before correlating.
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)

        S = (H.T @ H) / max(1, H.shape[0] - 1)
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)

        # Within-layer only. This keeps the graph interpretable and avoids
        # overfitting tiny cross-layer coincidences.
        allowed = torch.zeros_like(S, dtype=torch.bool)
        for _, ids in self.module_to_chunk_ids.items():
            allowed[ids[:, None], ids[None, :]] = True

        S = torch.where(allowed, S, torch.zeros_like(S))
        return S

    def knn_scores(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        k_neighbors: int = 3,
    ) -> torch.Tensor:
        """Estimate inactive chunk masses from their k most similar active peers."""
        if self.similarity is None:
            return self.predicted_mass.clone()

        S = self.similarity
        scores = self.predicted_mass.clone()
        # Active chunks are known exactly; anchor them to the true masses.
        scores[active_mask] = true_masses[active_mask]

        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()

        if active_idx.numel() == 0:
            return scores

        for i in inactive_idx.tolist():
            weights = S[i, active_idx]
            if weights.sum() <= 1e-12:
                continue  # no similar active chunk; keep the EMA estimate
            kk = min(k_neighbors, weights.numel())
            top = torch.topk(weights, k=kk)
            w = top.values
            aidx = active_idx[top.indices]
            # Similarity-weighted average of the neighbors' true masses.
            scores[i] = (w * true_masses[aidx]).sum() / (w.sum() + 1e-12)

        return scores

    def diffusion_scores(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        diffusion_steps: int = 8,
        alpha: float = 0.7,
    ) -> torch.Tensor:
        """Propagate observed masses over the similarity graph by iterated
        row-stochastic smoothing, re-clamping active chunks each iteration."""
        if self.similarity is None:
            return self.predicted_mass.clone()

        S = self.similarity
        # Row-normalize so each diffusion step is a weighted neighbor average.
        W = S / (S.sum(dim=1, keepdim=True) + 1e-12)

        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]

        for _ in range(diffusion_steps):
            proposal = W @ scores
            scores = alpha * proposal + (1.0 - alpha) * scores
            # Observed chunks stay pinned to their true values.
            scores[active_mask] = true_masses[active_mask]

        return torch.clamp(scores, min=0.0)

    def oracle_topk_mask(self, true_masses: torch.Tensor) -> torch.Tensor:
        """Upper-bound mask: top-k chunks by the true current masses."""
        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        mask[torch.topk(true_masses, k=k).indices] = True
        return mask
479
+
480
+
481
+ # -----------------------------
482
+ # Gradient installation and metrics
483
+ # -----------------------------
484
+
485
@torch.no_grad()
def install_active_only_grads(sched: RadarScheduler, active_mask: torch.Tensor) -> None:
    """Zero the gradients of every inactive chunk so the optimizer only moves
    active chunks.

    Args:
        sched: scheduler holding the chunk layout (``module_to_chunk_ids`` maps
            each module to its global chunk ids; chunks are contiguous
            ``chunk_size``-row slices along dim 0 of weight/bias).
        active_mask: global boolean mask, True for chunks that keep gradients.
    """
    if sched.scheduler == "dense":
        return  # dense policy trains everything; nothing to mask

    cs = sched.chunk_size
    for m, ids in sched.module_to_chunk_ids.items():
        inactive = ~active_mask[ids]
        if not bool(inactive.any()):
            continue

        # Chunks are contiguous row-slices along dim 0, so a single masked
        # write replaces the per-chunk Python loop (weight and bias alike).
        # view() shares storage with the grad, so the assignment is in place.
        if m.weight.grad is not None:
            m.weight.grad.view(inactive.numel(), cs, -1)[inactive] = 0.0
        if m.bias is not None and m.bias.grad is not None:
            m.bias.grad.view(inactive.numel(), cs)[inactive] = 0.0
506
+
507
+
508
def dense_cosine_active_only(vecs: List[torch.Tensor], active_mask: torch.Tensor) -> float:
    """Cosine similarity between the full gradient and its active-only version.

    Inactive chunk vectors are replaced by zeros in the approximation.
    """
    full = torch.cat([v.flatten() for v in vecs])
    kept = [
        v.flatten() if bool(active_mask[i]) else torch.zeros_like(v).flatten()
        for i, v in enumerate(vecs)
    ]
    approx = torch.cat(kept)
    return float(F.cosine_similarity(full, approx, dim=0).item())
515
+
516
+
517
def jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
    """Jaccard overlap |a&b| / |a|b| of two boolean masks (0.0 if both empty)."""
    intersection = (a & b).sum().float()
    union = torch.clamp((a | b).sum().float(), min=1.0)
    return float((intersection / union).item())
521
+
522
+
523
class SimpleAdam:
    """Bare-bones Adam (fixed betas 0.9/0.999, no bias correction) over all
    parameters of a model. Parameters without a gradient are skipped."""

    def __init__(self, model: nn.Module, lr: float = 3e-4):
        self.model = model
        self.lr = lr
        # Lazily created first/second moment buffers, keyed by parameter.
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

    def zero_grad(self):
        """Drop all gradients (sets .grad to None)."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one Adam update to every parameter that has a gradient."""
        for p in self.model.parameters():
            grad = p.grad
            if grad is None:
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            moments = self.state[p]
            moments["m"].mul_(0.9).add_(grad, alpha=0.1)
            moments["v"].mul_(0.999).addcmul_(grad, grad, value=0.001)
            p.sub_(moments["m"] / (torch.sqrt(moments["v"]) + 1e-8), alpha=self.lr)
545
+
546
+
547
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Validation loss on one fixed (seeded) batch.

    Fix: the model is switched back to train mode in a ``finally`` block, so an
    exception during the forward pass no longer leaves it stuck in eval mode.
    """
    model.eval()
    try:
        with torch.no_grad():
            x, y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
            _, loss = model(x, y)
    finally:
        model.train()
    return float(loss.item())
554
+
555
+
556
def run_experiment(
    scheduler_name: Scheduler,
    alpha: float,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train one MiniGPT under the given scheduler and return summary metrics.

    Returns a dict with: "val" (final validation loss), "ms" (mean ms/step),
    and post-warmup averages of "cos" (active-only gradient cosine), "jacc"
    (overlap with the current-step oracle mask), "stable" (mask overlap with
    the previous step). The latter three are NaN for dense/no-metric runs.
    Note: loss.backward() still computes dense gradients; masses/oracle use
    them for audit, then inactive-chunk grads are zeroed before the update.
    """
    set_seed(42)

    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    sched = RadarScheduler(
        model=model,
        chunk_size=chunk_size,
        active_fraction=active_fraction,
        device=device,
        scheduler=scheduler_name,
        alpha=alpha,
    )

    metric_rows = []

    # sync_device (defined elsewhere) presumably flushes the device queue so
    # the timer measures real compute — TODO confirm for MPS/CUDA.
    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()

    for step in range(steps):
        # Step-seeded CPU generator keeps the batch sequence reproducible.
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))

        active_mask = sched.choose_mask(step=step, warmup_steps=warmup_steps)

        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        # Dense per-chunk gradient vectors and their norms (audit + scheduler).
        vecs = sched.chunk_gradient_vectors()
        masses = sched.chunk_masses_from_vecs(vecs)

        if step >= warmup_steps and scheduler_name != "dense":
            oracle = sched.oracle_topk_mask(masses)
            metric_rows.append(
                {
                    "cos": dense_cosine_active_only(vecs, active_mask),
                    "jacc": jaccard(active_mask, oracle),
                    "stable": jaccard(active_mask, sched.prev_mask) if sched.prev_mask is not None else 0.0,
                }
            )

        # Zero inactive-chunk grads so the optimizer only moves active chunks.
        install_active_only_grads(sched, active_mask)

        # The scheduler only learns from active chunks after warmup.
        # During warmup it observes everything to build the similarity graph.
        observed_for_scheduler = active_mask if step >= warmup_steps else torch.ones_like(active_mask)
        sched.update_from_observed(
            observed_mask=observed_for_scheduler,
            true_masses=masses,
            step=step,
            warmup_steps=warmup_steps,
        )

        sched.prev_mask = active_mask.clone()
        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0

    val_loss = evaluate(model, corpus, batch_size, seed=12345)

    if metric_rows:
        avg_cos = sum(r["cos"] for r in metric_rows) / len(metric_rows)
        avg_jacc = sum(r["jacc"] for r in metric_rows) / len(metric_rows)
        avg_stable = sum(r["stable"] for r in metric_rows) / len(metric_rows)
    else:
        avg_cos = float("nan")
        avg_jacc = float("nan")
        avg_stable = float("nan")

    return {
        "val": val_loss,
        "ms": 1000.0 * elapsed / steps,
        "cos": avg_cos,
        "jacc": avg_jacc,
        "stable": avg_stable,
    }
650
+
651
+
652
def build_runs(alphas: List[float]) -> List[Tuple[Scheduler, float, str]]:
    """Enumerate (scheduler, alpha, label) configurations to benchmark.

    Order: fixed baselines, radar_knn for each alpha, radar_graph for each
    alpha, then the random control.
    """
    baselines: List[Tuple[Scheduler, float, str]] = [
        ("dense", 1.0, "dense"),
        ("ema_topk", 1.0, "ema_topk"),
        ("knn_scheduler", 0.0, "knn"),
        ("graph_scheduler", 0.0, "graph"),
    ]
    radar_knn = [("radar_knn", a, f"radar_knn_a{a:g}") for a in alphas]
    radar_graph = [("radar_graph", a, f"radar_graph_a{a:g}") for a in alphas]
    return baselines + radar_knn + radar_graph + [("random", 0.0, "random")]
667
+
668
+
669
def main() -> None:
    """CLI entry point: run every scheduler configuration and print a table."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--alphas", type=float, nargs="+", default=[0.25, 0.5, 0.75, 0.9])
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    # Banner describing the sweep configuration.
    print("\nRadar scheduler diagnostic")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunks={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps}")
    print(f"alphas={args.alphas}\n")
    header = (
        f"{'run':>18s} | {'val':>8s} | {'ms/step':>8s} | "
        f"{'grad_cos':>8s} | {'jacc':>8s} | {'stable':>8s}"
    )
    print(header)
    print("-" * 78)

    for scheduler_name, alpha, label in build_runs(args.alphas):
        result = run_experiment(
            scheduler_name=scheduler_name,
            alpha=alpha,
            device=args.device,
            steps=args.steps,
            batch_size=args.batch_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            chunk_size=args.chunk_size,
            active_fraction=args.active_fraction,
            warmup_steps=args.warmup_steps,
            benchmark_sync=args.benchmark_sync,
        )

        # One aligned table row per configuration.
        row = " | ".join(
            [
                f"{label:>18s}",
                f"{result['val']:8.4f}",
                f"{result['ms']:8.2f}",
                f"{result['cos']:8.3f}",
                f"{result['jacc']:8.3f}",
                f"{result['stable']:8.3f}",
            ]
        )
        print(row)
722
+
723
+
724
# Run the diagnostic sweep when executed directly as a script.
if __name__ == "__main__":
    main()
experiments/sparse_transformer_v6.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v6: stable predicted-magnitude masks, no dense refresh by default.
3
+
4
+ This prototype is designed to test the next hypothesis after the spiral/MLP runs:
5
+
6
+ The important gradient support is heavy-tailed and temporally stable enough
7
+ that we can select active parameter blocks from history, freeze the rest, and
8
+ still train a harder sequence model.
9
+
10
+ Key fixes versus v5
11
+ -------------------
12
+ 1. Harder model: a small causal Transformer language model.
13
+ 2. No periodic dense refresh by default: --warmup_steps 0.
14
+ 3. The selector only learns from blocks it actually observes/updates.
15
+ 4. Inactive Linear rows are truly frozen by MaskedAdam. This matters because
16
+ ordinary Adam can still move parameters with zero gradients through momentum.
17
+ 5. A true current-step oracle is included as an audit upper bound.
18
+ 6. Random masks are included as a control.
19
+
20
+ Important limitation
21
+ --------------------
22
+ This still calls loss.backward(), so PyTorch computes dense gradients. Those full
23
+ current gradients are used for audit metrics and for the oracle run only. The
24
+ practical predicted_magnitude selector is not allowed to update its statistics
25
+ from inactive full gradients.
26
+
27
+ Actual speedup would require structured partial backward/custom kernels.
28
+
29
+ Run
30
+ ---
31
+ python3 sparse_transformer_v6.py --quick
32
+ python3 sparse_transformer_v6.py --steps 1000 --active_fractions 0.10 0.05 0.02
33
+ python3 sparse_transformer_v6.py --text_path input.txt --steps 2000
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import argparse
39
+ import math
40
+ import random
41
+ from typing import Dict, List, Literal, Optional, Tuple
42
+
43
+ import torch
44
+ torch.set_num_threads(1)
45
+ import torch.nn as nn
46
+ import torch.nn.functional as F
47
+
48
+ Policy = Literal["predicted_magnitude", "oracle_current", "random"]
49
+
50
+
51
def set_seed(seed: int) -> None:
    """Seed the torch (CPU and, when present, all CUDA) and Python RNGs."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
56
+
57
+
58
def device() -> str:
    """Preferred compute device: "cuda" when available, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
60
+
61
+
62
+ # -----------------------------
63
+ # Data
64
+ # -----------------------------
65
+
66
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Deterministically generate a newline-terminated synthetic corpus.

    Six templated sentence kinds mix fixed vocab lists; the same
    (n_sentences, seed) pair always yields the same text.
    """
    rng = random.Random(seed)
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]

    def sentence() -> str:
        # NOTE: the rng call order per branch matches the original generator,
        # so identical seeds reproduce identical corpora.
        kind = rng.randrange(6)
        if kind == 0:
            return f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
        if kind == 1:
            return f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
        if kind == 2:
            a, b = rng.sample(symbols, 2)
            return (
                f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; "
                f"rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}."
            )
        if kind == 3:
            return f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
        if kind == 4:
            seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
            return f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}."
        return f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."

    return "\n".join(sentence() for _ in range(n_sentences)) + "\n"
101
+
102
+
103
class CharCorpus:
    """Character-level corpus with a 90/10 train/val split and random batching."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        cut = int(0.9 * len(encoded))
        self.train_data = encoded[:cut]
        self.val_data = encoded[cut:]

    def get_batch(self, split: str, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample (inputs, next-char targets) of shape [batch, block_size]."""
        source = self.train_data if split == "train" else self.val_data
        max_start = len(source) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        starts = torch.randint(max_start, (batch_size,))
        x = torch.stack([source[s : s + self.block_size] for s in starts])
        y = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return x.to(self.device), y.to(self.device)
126
+
127
+
128
def load_text(args: argparse.Namespace) -> str:
    """Return the training text: file contents when --text_path is set,
    otherwise a freshly generated synthetic corpus."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as f:
        return f.read()
133
+
134
+
135
+ # -----------------------------
136
+ # Mini GPT
137
+ # -----------------------------
138
+
139
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an explicit triangular mask."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq, width = x.shape
        # One fused projection yields q, k, v side by side.
        q, k, v = self.c_attn(x).split(width, dim=2)

        head_shape = (batch, seq, self.n_head, self.head_dim)
        q = q.view(*head_shape).transpose(1, 2)
        k = k.view(*head_shape).transpose(1, 2)
        v = v.view(*head_shape).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :seq, :seq] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))

        mixed = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, width)
        return self.c_proj(mixed)
164
+
165
+
166
class FeedForward(nn.Module):
    """Two-layer position-wise MLP (4x expansion, GELU) with output dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
175
+
176
+
177
class Block(nn.Module):
    """Pre-norm transformer block: residual attention, then residual MLP."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
189
+
190
+
191
class MiniGPT(nn.Module):
    """Minimal GPT-style LM: embeddings + dropout, block stack, LayerNorm,
    linear head. Returns (logits, loss) where loss is None without targets."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        stack = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*stack)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.drop(self.tok_emb(idx) + self.pos_emb(positions)[None, :, :])
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)

        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
214
+
215
+
216
def named_linear_modules(model: nn.Module) -> List[Tuple[str, nn.Linear]]:
    """Return (qualified_name, module) pairs for every nn.Linear in *model*."""
    found: List[Tuple[str, nn.Linear]] = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            found.append((name, module))
    return found
218
+
219
+
220
+ # -----------------------------
221
+ # Mask selector
222
+ # -----------------------------
223
+
224
class RowMasker:
    """Selects which Linear output rows ("blocks") stay trainable each step.

    Policies: "predicted_magnitude" (top-k by EMA of observed gradient mass,
    plus an exploration quota), "oracle_current" (top-k by the true current
    gradient mass, audit upper bound), and "random" (control). During warmup
    every row is active. The EMA is updated strictly from rows that were
    active/observed, so the practical selector never peeks at frozen rows.
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        # EMA coefficient for observed gradient mass.
        self.mass_beta = mass_beta
        # Multiplicative decay applied to ALL predicted masses each step,
        # so stale estimates for unobserved rows fade over time.
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.device = device

        # One global block id per output row of every nn.Linear in the model.
        self.linear_modules = [m for _, m in named_linear_modules(model)]
        self.module_to_ids: Dict[nn.Linear, torch.Tensor] = {}
        ids = []
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            block_ids = torch.arange(offset, offset + n, device=device)
            self.module_to_ids[m] = block_ids
            ids.append(block_ids)
            offset += n
        self.n_blocks = offset

        # Start optimistic (all ones) so every row can win selection early.
        self.predicted_mass = torch.ones(self.n_blocks, device=device)
        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[nn.Linear, torch.Tensor] = {m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules}

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask of the top `fraction` entries of *values* (at least 1)."""
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        mask[torch.topk(values, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard overlap of two boolean masks (0.0 when both are empty)."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        """Install a new global active mask and rebuild per-module row masks."""
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]

    def choose_pre_backward(self, step: int) -> None:
        """Pick the active set before backward (policies that don't need
        current gradients); oracle defers selection to audit_and_update."""
        if step < self.warmup_steps:
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "oracle_current":
            # Cannot select until after current gradients are known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        k_total = max(1, int(self.active_fraction * self.n_blocks))

        if self.policy == "random":
            active = torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device)
            active[torch.randperm(self.n_blocks, device=self.device)[:k_total]] = True
            self._set_active(active)
            return

        if self.policy != "predicted_magnitude":
            raise ValueError(f"Unknown policy: {self.policy}")

        # Split the budget into exploit (top predicted mass) + explore (random).
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device)

        # Tiny jitter breaks ties among equal predicted masses.
        scores = self.predicted_mass + 1e-8 * torch.rand_like(self.predicted_mass)
        if k_exploit > 0:
            active[torch.topk(scores, k=k_exploit).indices] = True
        if k_explore > 0:
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            active[remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]] = True
        self._set_active(active)

    @torch.no_grad()
    def current_gradient_mass(self) -> torch.Tensor:
        """Per-row gradient L2 norm (weight row + bias entry) for all blocks."""
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            # 1e-30 keeps sqrt differentiable-safe / avoids exact zeros.
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def audit_and_update(self, step: int) -> Dict[str, float]:
        """After backward: finalize the active set (oracle/warmup), compute
        audit metrics against the full gradient, and update the EMA from
        observed rows only. Returns a dict of scalar metrics."""
        mass = self.current_gradient_mass()

        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
        # With zero inactive blocks and active blocks using true gradient, cosine == norm ratio.
        norm_ratio = cosine

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        # Heavy-tail diagnostic: gradient mass share of the top 20% of rows.
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        # Strict rule: do not update stats from inactive full gradients.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        self.predicted_mass[observed] = (
            self.mass_beta * self.predicted_mass[observed]
            + (1.0 - self.mass_beta) * mass[observed]
        )

        return {
            "cosine": cosine,
            "norm_ratio": norm_ratio,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
        }

    def row_mask_for(self, module: nn.Linear) -> Optional[torch.Tensor]:
        """Current boolean row mask for *module* (None if it is unknown)."""
        return self.row_masks.get(module)
372
+
373
+
374
+ # -----------------------------
375
+ # Masked Adam
376
+ # -----------------------------
377
+
378
class MaskedAdam:
    """Adam variant that leaves masked-out Linear rows — and their moments — untouched.

    When a RowMasker is attached, only active rows of each Linear receive the
    update and advance their Adam state; everything else is frozen in place.
    """

    def __init__(self, model: nn.Module, masker: Optional[RowMasker], lr: float, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.0):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        # Lazily created first/second moment buffers, keyed by parameter.
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Maps each Linear weight/bias parameter back to its module and role.
        self.linear_param: Dict[nn.Parameter, Tuple[nn.Linear, str]] = {}
        for _, lin in named_linear_modules(model):
            self.linear_param[lin.weight] = (lin, "weight")
            if lin.bias is not None:
                self.linear_param[lin.bias] = (lin, "bias")

    def zero_grad(self) -> None:
        """Drop gradients entirely rather than zero-filling them."""
        for param in self.model.parameters():
            param.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one Adam update, restricted to active rows where a mask exists."""
        for param in self.model.parameters():
            grad = param.grad
            if grad is None:
                continue
            if param not in self.state:
                self.state[param] = {"m": torch.zeros_like(param), "v": torch.zeros_like(param)}
            exp_avg = self.state[param]["m"]
            exp_avg_sq = self.state[param]["v"]
            if self.weight_decay:
                grad = grad.add(param, alpha=self.weight_decay)

            broadcast_mask = None
            if self.masker is not None and param in self.linear_param:
                module, role = self.linear_param[param]
                rows = self.masker.row_mask_for(module)
                if rows is not None:
                    # Weights broadcast the row mask over trailing dims; biases use it directly.
                    broadcast_mask = rows if role == "bias" else rows.view(-1, *([1] * (param.ndim - 1)))

            if broadcast_mask is None:
                # Dense path: ordinary Adam (no bias correction in this minimal variant).
                exp_avg.mul_(self.beta1).add_(grad, alpha=1.0 - self.beta1)
                exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1.0 - self.beta2)
                param.add_(exp_avg / (torch.sqrt(exp_avg_sq) + self.eps), alpha=-self.lr)
                continue

            elem_mask = broadcast_mask.expand_as(param)
            if not bool(elem_mask.any().item()):
                continue
            # Moments advance only on active elements; inactive rows keep stale state.
            cand_m = self.beta1 * exp_avg + (1.0 - self.beta1) * grad
            cand_v = self.beta2 * exp_avg_sq + (1.0 - self.beta2) * grad * grad
            exp_avg[elem_mask] = cand_m[elem_mask]
            exp_avg_sq[elem_mask] = cand_v[elem_mask]
            delta = exp_avg / (torch.sqrt(exp_avg_sq) + self.eps)
            param[elem_mask] = param[elem_mask] - self.lr * delta[elem_mask]
431
+
432
+
433
+ # -----------------------------
434
+ # Training
435
+ # -----------------------------
436
+
437
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int) -> Dict[str, float]:
    """Mean train/val loss over eval_iters random batches.

    Temporarily switches the model to eval mode and restores train mode before
    returning.
    """
    model.eval()
    results: Dict[str, float] = {}
    for split in ("train", "val"):
        total = 0.0
        for _ in range(eval_iters):
            xb, yb = corpus.get_batch(split, batch_size)
            _, loss = model(xb, yb)
            total += float(loss.item())
        results[split] = total / eval_iters
    model.train()
    return results
450
+
451
+
452
def train_run(corpus: CharCorpus, args: argparse.Namespace, policy: Optional[Policy], active_fraction: float, seed_offset: int) -> Dict[str, float | str]:
    """Train one model configuration and return its summary metrics row.

    policy=None runs the dense baseline (no masker); otherwise a RowMasker
    drives which Linear rows the MaskedAdam step may touch.
    """
    set_seed(args.seed + seed_offset)
    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=args.explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=args.warmup_steps,
            device=dev,
        )
    opt = MaskedAdam(model, masker, lr=args.lr, weight_decay=args.weight_decay)

    # Running sums of per-step audit metrics, collected only after warmup.
    sums = {"cosine": 0.0, "norm_ratio": 0.0, "top20_mass": 0.0, "jacc_oracle": 0.0, "stability": 0.0, "active_fraction_real": 0.0}
    count = 0

    for step in range(args.steps):
        x, y = corpus.get_batch("train", args.batch_size)
        if masker is not None:
            # Select the active row set before gradients exist (oracle defers).
            masker.choose_pre_backward(step)
        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        if masker is not None:
            # Dense gradients are available here; audit selection quality and
            # update the masker's EMA statistics from active rows.
            metrics = masker.audit_and_update(step)
            if step >= args.warmup_steps:
                for k in sums:
                    sums[k] += metrics[k]
                count += 1
        opt.step()

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
            name = "dense" if policy is None else policy
            print(f"{name:20s} step={step:5d} train={losses['train']:.4f} val={losses['val']:.4f}")

    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "target_active": 1.0 if policy is None else active_fraction,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        # Dense baseline (or no audited steps): mask metrics are undefined.
        row.update({"cosine": float("nan"), "norm_ratio": float("nan"), "top20_mass": float("nan"), "jacc_oracle": float("nan"), "stability": float("nan"), "active_fraction_real": 1.0})
    else:
        for k, v in sums.items():
            row[k] = v / count
    return row
507
+
508
+
509
def print_summary(rows: List[Dict[str, float | str]]) -> None:
    """Render the per-run metrics table to stdout."""
    print("\nSummary")
    header = f"{'run':>22s} {'target':>7s} {'actual':>7s} {'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} {'stable':>7s}"
    print(header)
    print("-" * len(header))
    for entry in rows:
        cells = [
            f"{str(entry['run']):>22s}",
            f"{float(entry['target_active']):7.3f}",
            f"{float(entry['active_fraction_real']):7.3f}",
            f"{float(entry['val_loss']):8.4f}",
            f"{float(entry['train_loss']):8.4f}",
            f"{float(entry['cosine']):7.3f}",
            f"{float(entry['top20_mass']):7.3f}",
            f"{float(entry['jacc_oracle']):7.3f}",
            f"{float(entry['stability']):7.3f}",
        ]
        print(" ".join(cells))
526
+
527
+
528
def parse_args() -> argparse.Namespace:
    """Build and parse the experiment's command-line interface."""
    parser = argparse.ArgumentParser()
    # Data.
    parser.add_argument("--text_path", type=str, default=None)
    parser.add_argument("--synthetic_sentences", type=int, default=12000)
    # Training schedule.
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--block_size", type=int, default=64)
    # Model size.
    parser.add_argument("--n_layer", type=int, default=2)
    parser.add_argument("--n_head", type=int, default=4)
    parser.add_argument("--n_embd", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.0)
    # Optimizer.
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    # Sparse-selection policy knobs.
    parser.add_argument("--active_fractions", type=float, nargs="+", default=[0.10, 0.05, 0.02])
    parser.add_argument("--explore_fraction", type=float, default=0.10)
    parser.add_argument("--mass_beta", type=float, default=0.95)
    parser.add_argument("--unobserved_decay", type=float, default=0.999)
    parser.add_argument("--warmup_steps", type=int, default=0)
    # Evaluation and reproducibility.
    parser.add_argument("--eval_interval", type=int, default=200)
    parser.add_argument("--eval_iters", type=int, default=20)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--verbose", action="store_true")
    return parser.parse_args()
552
+
553
+
554
def main() -> None:
    """Entry point: run the dense baseline plus the policy x active-fraction sweep."""
    args = parse_args()
    if args.quick:
        # Shrink every dimension of the experiment for a fast smoke test.
        args.steps = 60
        args.eval_iters = 3
        args.batch_size = 16
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 2000
        args.active_fractions = [0.10, 0.02]

    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"warmup_steps={args.warmup_steps} explore_fraction={args.explore_fraction}")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(train_run(corpus, args, policy=None, active_fraction=1.0, seed_offset=0))

    # Distinct seed offsets keep each sweep cell independent but reproducible.
    seed_offset = 100
    for af in args.active_fractions:
        for policy in ["oracle_current", "predicted_magnitude", "random"]:
            print(f"\nRunning policy={policy}, active_fraction={af:.3f}")
            rows.append(train_run(corpus, args, policy=policy, active_fraction=af, seed_offset=seed_offset))
            seed_offset += 1

    print_summary(rows)

    print("\nNotes")
    print(" oracle_current uses the current full gradient to choose rows; it is an upper bound, not a practical selector.")
    print(" predicted_magnitude chooses from EMA mass only, plus a small random exploration budget.")
    print(" EMA mass is updated only for active/observed rows, not all rows.")
    print(" inactive Linear rows are frozen by MaskedAdam, including Adam state; zero grad alone is not enough.")
    print(" dense gradients are still computed for audit, so this is not a wall-clock speed benchmark yet.")
587
+ print("\nNotes")
588
+ print(" oracle_current uses the current full gradient to choose rows; it is an upper bound, not a practical selector.")
589
+ print(" predicted_magnitude chooses from EMA mass only, plus a small random exploration budget.")
590
+ print(" EMA mass is updated only for active/observed rows, not all rows.")
591
+ print(" inactive Linear rows are frozen by MaskedAdam, including Adam state; zero grad alone is not enough.")
592
+ print(" dense gradients are still computed for audit, so this is not a wall-clock speed benchmark yet.")
593
+
594
+
595
+ if __name__ == "__main__":
596
+ main()
experiments/sparse_transformer_v7.py ADDED
@@ -0,0 +1,780 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v7: discovery stress tests for stable gradient-support masks.
3
+
4
+ This version follows the v6 result:
5
+
6
+ oracle_current works far better than random, so useful sparse support exists;
7
+ predicted_magnitude without warmup does not reliably discover that support.
8
+
9
+ v7 focuses on discovery mechanisms:
10
+
11
+ 1. predicted_magnitude
12
+ Exploit rows with the largest EMA-observed gradient mass.
13
+
14
+ 2. ucb_magnitude
15
+ A bandit-style selector: EMA mass + an uncertainty bonus for under-observed rows.
16
+ This is meant to discover useful rows without dense refresh.
17
+
18
+ First observation initializes EMA scale immediately.
19
+
20
+ 3. stale_current
21
+ A renamed diagnostic control: use the previous full-gradient mass. It is not
22
+ practical because it relies on dense audit gradients, but it tells us whether
23
+ one-step lag is too noisy.
24
+
25
+ 4. oracle_current
26
+ True current top-k by dense gradient mass. Upper bound only.
27
+
28
+ 5. random
29
+ Control.
30
+
31
+ Important limitation
32
+ --------------------
33
+ This still calls loss.backward(), so PyTorch computes dense gradients. Dense
34
+ current gradients are used for audit metrics and for oracle/stale controls.
35
+ The practical selectors only update their EMA statistics from active rows.
36
+ Actual speedup would require structured partial backward/custom kernels.
37
+
38
+ Example runs
39
+ ------------
40
+ Smoke test:
41
+ python3 sparse_transformer_v7.py --quick
42
+
43
+ No-warmup discovery test:
44
+ python3 sparse_transformer_v7.py --steps 1000 \
45
+ --active_fractions 0.10 0.05 0.02 \
46
+ --policies predicted_magnitude ucb_magnitude oracle_current random \
47
+ --warmup_steps_list 0 5 50 --explore_fractions 0.10 0.30
48
+
49
+ Warm-start separation test:
50
+ python3 sparse_transformer_v7.py --steps 1000 \
51
+ --active_fractions 0.10 0.05 0.02 \
52
+ --policies predicted_magnitude ucb_magnitude oracle_current random \
53
+ --warmup_steps_list 0 5 50 200 --explore_fractions 0.10
54
+ """
55
+
56
+ from __future__ import annotations
57
+
58
+ import argparse
59
+ import math
60
+ import random
61
+ from typing import Dict, List, Literal, Optional, Tuple
62
+
63
+ import torch
64
+
65
+ torch.set_num_threads(1)
66
+ import torch.nn as nn
67
+ import torch.nn.functional as F
68
+
69
+ Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"]
70
+
71
+
72
def set_seed(seed: int) -> None:
    """Make the python and torch RNG streams reproducible for this run."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    # Seed every visible CUDA device as well.
    torch.cuda.manual_seed_all(seed)
77
+
78
+
79
def device() -> str:
    """Preferred compute device name: CUDA when available, otherwise CPU."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
81
+
82
+
83
+ # -----------------------------
84
+ # Data
85
+ # -----------------------------
86
+
87
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Generate a deterministic synthetic character corpus of templated sentences.

    Six sentence templates are sampled uniformly; the RNG is local so the
    output depends only on (n_sentences, seed).
    """
    rng = random.Random(seed)
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]

    sentences: List[str] = []
    for _ in range(n_sentences):
        template = rng.randrange(6)
        if template == 0:
            sentence = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
        elif template == 1:
            sentence = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
        elif template == 2:
            sym_a, sym_b = rng.sample(symbols, 2)
            sentence = f"rule {sym_a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {sym_b}: {rng.choice(objects)} -> {rng.choice(objects)}."
        elif template == 3:
            sentence = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
        elif template == 4:
            run = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
            sentence = f"sequence {run} ends when {rng.choice(names)} {rng.choice(verbs)}."
        else:
            sentence = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."
        sentences.append(sentence)
    return "\n".join(sentences) + "\n"
122
+
123
+
124
class CharCorpus:
    """Character-level language-modeling dataset with a 90/10 train/val split."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        cut = int(0.9 * len(encoded))
        self.train_data = encoded[:cut]
        self.val_data = encoded[cut:]

    def get_batch(self, split: str, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample random (input, next-char target) windows from one split."""
        source = self.train_data if split == "train" else self.val_data
        max_start = len(source) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        starts = torch.randint(max_start, (batch_size,))
        # Targets are the inputs shifted one character to the right.
        x = torch.stack([source[s : s + self.block_size] for s in starts])
        y = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return x.to(self.device), y.to(self.device)
147
+
148
+
149
def load_text(args: argparse.Namespace) -> str:
    """Return the corpus text: a user-supplied file, or the synthetic generator."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as f:
        return f.read()
154
+
155
+
156
+ # -----------------------------
157
+ # Mini GPT
158
+ # -----------------------------
159
+
160
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with an explicit lower-triangular causal mask."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Buffer so the mask follows the module across devices without being a parameter.
        tril = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", tril.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.c_attn(x).split(C, dim=2))
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Future positions are masked to -inf before the softmax.
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        ctx = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(ctx)
185
+
186
+
187
class FeedForward(nn.Module):
    """Standard 4x-expansion GELU MLP with dropout on the output projection."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
196
+
197
+
198
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each behind a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
210
+
211
+
212
class MiniGPT(nn.Module):
    """Tiny GPT-style decoder-only model for character-level LM experiments."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None unless targets are supplied."""
        B, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        hidden = self.drop(self.tok_emb(idx) + self.pos_emb(positions)[None, :, :])
        hidden = self.ln_f(self.blocks(hidden))
        logits = self.lm_head(hidden)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
235
+
236
+
237
def named_linear_modules(model: nn.Module) -> List[Tuple[str, nn.Linear]]:
    """All nn.Linear submodules of *model*, paired with their qualified names."""
    found: List[Tuple[str, nn.Linear]] = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            found.append((name, module))
    return found
239
+
240
+
241
def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    """(total params, params inside Linear layers, linear share of total)."""
    total = sum(p.numel() for p in model.parameters())
    linear = sum(
        m.weight.numel() + (m.bias.numel() if m.bias is not None else 0)
        for _, m in named_linear_modules(model)
    )
    # max(1, total) guards against a parameter-free model.
    return total, linear, linear / max(1, total)
249
+
250
+
251
+ # -----------------------------
252
+ # Mask selector
253
+ # -----------------------------
254
+
255
class RowMasker:
    """Selects which output rows of each Linear layer are 'active' per step.

    Rows across all Linear layers are flattened into a single block index
    space; the chosen policy picks a boolean mask over that space before each
    backward pass, and audit_and_update() scores the pick against the dense
    gradient afterwards.

    Fix vs previous revision: the local `ids` accumulator in __init__ was dead
    code (built but never read) and has been removed.
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        ucb_alpha: float,
        mass_init: float,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.ucb_alpha = ucb_alpha
        self.mass_init = mass_init
        self.device = device
        self.step_index = 0

        # Assign each Linear's output rows a contiguous range of global block ids.
        self.linear_modules = [m for _, m in named_linear_modules(model)]
        self.module_to_ids: Dict[nn.Linear, torch.Tensor] = {}
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            self.module_to_ids[m] = torch.arange(offset, offset + n, device=device)
            offset += n
        self.n_blocks = offset

        # Per-row statistics: EMA of observed gradient mass, last dense mass
        # (diagnostic only), observation counts, and a global scale EMA for UCB.
        self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.observed_count = torch.zeros(self.n_blocks, device=device)
        self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)

        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[nn.Linear, torch.Tensor] = {
            m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
        }

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask selecting the top `fraction` of entries by value."""
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        # Tie-breaking noise matters when many rows have identical initial scores.
        noisy = values + 1e-9 * torch.rand_like(values)
        mask[torch.topk(noisy, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard similarity of two boolean masks (empty union -> 0)."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        """Install a global active mask and slice it into per-module row masks."""
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]

    def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
        """Top-k by score for the exploit budget, uniform random for the rest."""
        n = self.n_blocks
        k_total = max(1, int(self.active_fraction * n))
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(n, dtype=torch.bool, device=self.device)

        if k_exploit > 0:
            active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
        if k_explore > 0:
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
            active[pick] = True
        return active

    def choose_pre_backward(self, step: int) -> None:
        """Pick the active set before gradients exist, per the configured policy."""
        self.step_index = step
        if step < self.warmup_steps:
            # Warmup: train densely so every row accumulates statistics.
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "oracle_current":
            # Cannot select until after current gradients are known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "random":
            self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
            return

        if self.policy == "stale_current":
            # Diagnostic control: last step's dense mass (requires dense audit).
            self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
            return

        if self.policy == "predicted_magnitude":
            self._set_active(self._sample_exploit_explore(self.predicted_mass))
            return

        if self.policy == "ucb_magnitude":
            # EMA mass plus an uncertainty bonus that shrinks with observations.
            t = max(1, step - self.warmup_steps + 1)
            log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
            bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
            bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
            scores = self.predicted_mass + bonus
            self._set_active(self._sample_exploit_explore(scores))
            return

        raise ValueError(f"Unknown policy: {self.policy}")

    @torch.no_grad()
    def current_gradient_mass(self) -> torch.Tensor:
        """Per-row gradient L2 norms across all tracked Linear layers."""
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def audit_and_update(self, step: int) -> Dict[str, float]:
        """Score the active set against the dense gradient and refresh statistics."""
        mass = self.current_gradient_mass()

        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        # Approximation quality of the masked gradient vs the dense one.
        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
        norm_ratio = cosine

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        # Concentration: share of mass in the top 20% of rows.
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        new_active = active & (self.observed_count == 0)

        # Strict rule for practical policies: update stats only from active rows.
        # oracle_current and stale_current also update only active rows for consistency;
        # stale_current separately records last_full_mass as a diagnostic signal.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            # First observation should establish the real scale immediately.
            # Otherwise a beta=0.95 EMA needs many observations to climb from zero.
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()

        # Dense audit signal. Only stale_current is allowed to use this for next-step selection.
        self.last_full_mass = mass.detach().clone()

        coverage = float((self.observed_count > 0).float().mean().item())
        avg_obs_count = float(self.observed_count.mean().item())
        new_active_fraction = float((new_active.float().mean()).item())

        return {
            "cosine": cosine,
            "norm_ratio": norm_ratio,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
            "coverage": coverage,
            "avg_obs_count": avg_obs_count,
            "new_active_fraction": new_active_fraction,
        }

    def row_mask_for(self, module: nn.Linear) -> Optional[torch.Tensor]:
        """Boolean row mask for *module*, or None when untracked."""
        return self.row_masks.get(module)
+ return self.row_masks.get(module)
450
+
451
+
452
+ # -----------------------------
453
+ # Masked Adam
454
+ # -----------------------------
455
+
456
class MaskedAdam:
    """Adam that freezes weights and optimizer state on masked-out Linear rows.

    With a RowMasker attached, only active rows of each Linear receive an
    update and advance their moment buffers; inactive rows keep stale state.
    Optionally also freezes all non-Linear parameters during sparse runs.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        # Lazily created per-parameter moment buffers.
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Maps each Linear weight/bias parameter back to (module, role).
        self.linear_param: Dict[nn.Parameter, Tuple[nn.Linear, str]] = {}
        for _, m in named_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        """Drop gradients entirely rather than zero-filling them."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one Adam update, restricted to active rows where a mask exists."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                # Optional stricter mode: freeze embeddings/layernorm/etc. in sparse runs.
                continue

            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                g = g.add(p, alpha=self.weight_decay)

            # Resolve a broadcastable row mask for Linear weights/biases, if any.
            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                # Dense path: ordinary Adam (no bias correction in this minimal variant).
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                mask = row_mask.expand_as(p)
                if not bool(mask.any().item()):
                    continue
                # Moments advance only on masked-in elements; the rest stay frozen.
                new_m = self.beta1 * m + (1.0 - self.beta1) * g
                new_v = self.beta2 * v + (1.0 - self.beta2) * g * g
                m[mask] = new_m[mask]
                v[mask] = new_v[mask]
                update = m / (torch.sqrt(v) + self.eps)
                p[mask] = p[mask] - self.lr * update[mask]
+
524
+
525
+ # -----------------------------
526
+ # Training
527
+ # -----------------------------
528
+
529
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int) -> Dict[str, float]:
    """Average the model's loss over eval_iters random batches of each split."""
    model.eval()
    result: Dict[str, float] = {}
    for split in ("train", "val"):
        total = 0.0
        for _ in range(eval_iters):
            xb, yb = corpus.get_batch(split, batch_size)
            _, batch_loss = model(xb, yb)
            total += float(batch_loss.item())
        result[split] = total / eval_iters
    model.train()
    return result
542
+
543
+
544
def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    """Train one configuration and return its metrics row for the summary table.

    policy=None runs the dense baseline; otherwise a RowMasker drives which
    Linear rows MaskedAdam is allowed to update each step.
    """
    set_seed(args.seed + seed_offset)
    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )

    # Running sums of per-step audit metrics; averaged over non-warmup steps.
    sums = {
        "cosine": 0.0,
        "norm_ratio": 0.0,
        "top20_mass": 0.0,
        "jacc_oracle": 0.0,
        "stability": 0.0,
        "active_fraction_real": 0.0,
        "coverage": 0.0,
        "avg_obs_count": 0.0,
        "new_active_fraction": 0.0,
    }
    count = 0

    for step in range(args.steps):
        x, y = corpus.get_batch("train", args.batch_size)
        if masker is not None:
            # Select the active row set before gradients exist (oracle defers).
            masker.choose_pre_backward(step)
        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        if masker is not None:
            # Audit the mask against the dense gradient and refresh EMA stats.
            metrics = masker.audit_and_update(step)
            if step >= warmup_steps:
                for k in sums:
                    sums[k] += metrics[k]
                count += 1
        opt.step()

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
            name = "dense" if policy is None else policy
            print(
                f"{name:20s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"train={losses['train']:.4f} val={losses['val']:.4f}"
            )

    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        # Dense baseline (or an all-warmup run): mask metrics are undefined.
        row.update({
            "cosine": float("nan"),
            "norm_ratio": float("nan"),
            "top20_mass": float("nan"),
            "jacc_oracle": float("nan"),
            "stability": float("nan"),
            "active_fraction_real": 1.0,
            "coverage": float("nan"),
            "avg_obs_count": float("nan"),
            "new_active_fraction": float("nan"),
        })
    else:
        for k, v in sums.items():
            row[k] = v / count
    return row
640
+
641
+
642
def print_summary(rows: List[Dict[str, float | str]]) -> None:
    """Print a fixed-width table summarising every completed run."""
    print("\nSummary")
    header = (
        f"{'run':>22s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
        f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
        f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
    )
    print(header)
    print("-" * len(header))
    for entry in rows:
        cells = [
            f"{str(entry['run']):>22s}",
            f"{float(entry['target_active']):7.3f}",
            f"{float(entry['active_fraction_real']):7.3f}",
            f"{int(float(entry['warmup'])):5d}",
            f"{float(entry['explore']):5.2f}",
            f"{float(entry['val_loss']):8.4f}",
            f"{float(entry['train_loss']):8.4f}",
            f"{float(entry['cosine']):7.3f}",
            f"{float(entry['top20_mass']):7.3f}",
            f"{float(entry['jacc_oracle']):7.3f}",
            f"{float(entry['stability']):7.3f}",
            f"{float(entry['coverage']):7.3f}",
            f"{float(entry['new_active_fraction']):7.3f}",
        ]
        print(" ".join(cells))
667
+
668
+
669
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for this experiment."""
    parser = argparse.ArgumentParser()
    # Data source
    parser.add_argument("--text_path", type=str, default=None)
    parser.add_argument("--synthetic_sentences", type=int, default=12000)
    # Training schedule
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--block_size", type=int, default=64)
    # Model size
    parser.add_argument("--n_layer", type=int, default=2)
    parser.add_argument("--n_head", type=int, default=4)
    parser.add_argument("--n_embd", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.0)
    # Optimizer
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    # Sparsity sweep
    parser.add_argument("--active_fractions", type=float, nargs="+", default=[0.10, 0.05, 0.02])
    parser.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "ucb_magnitude", "random"])
    parser.add_argument("--explore_fractions", type=float, nargs="+", default=[0.10])
    parser.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5])
    parser.add_argument("--mass_beta", type=float, default=0.95)
    parser.add_argument("--unobserved_decay", type=float, default=1.0)
    parser.add_argument("--mass_init", type=float, default=0.0)
    parser.add_argument("--ucb_alpha", type=float, default=1.0)
    parser.add_argument("--freeze_non_linear_when_sparse", action="store_true")
    # Evaluation / reproducibility
    parser.add_argument("--eval_interval", type=int, default=200)
    parser.add_argument("--eval_iters", type=int, default=20)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--verbose", action="store_true")
    return parser.parse_args()
697
+
698
+
699
def main() -> None:
    """Run the dense baseline plus the full sparse-policy sweep and print a summary."""
    args = parse_args()
    if args.quick:
        # Shrink every dimension of the experiment for a fast smoke test.
        args.steps = 60
        args.eval_iters = 3
        args.batch_size = 16
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 2000
        args.active_fractions = [0.10, 0.02]
        args.policies = ["oracle_current", "predicted_magnitude", "ucb_magnitude", "random"]
        args.explore_fractions = [0.10]
        args.warmup_steps_list = [0]

    # Validate policy strings early.
    valid = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
    for pol in args.policies:
        if pol not in valid:
            raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid)}")

    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")

    # Report how much of the model is governed by row masks.
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(train_run(corpus, args, policy=None, active_fraction=1.0, warmup_steps=0, explore_fraction=0.0, seed_offset=0))

    # Distinct seed offsets keep each sparse run independent of the baseline.
    seed_offset = 100
    for af in args.active_fractions:
        for pol in args.policies:
            # oracle_current and stale_current do not use explore_fraction; random does not either.
            explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
            # Warmup matters for every sparse policy, so keep it in the loop.
            for warmup in args.warmup_steps_list:
                for explore in explore_values:
                    print(f"\nRunning policy={pol}, active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}")
                    rows.append(
                        train_run(
                            corpus,
                            args,
                            policy=pol,  # type: ignore[arg-type]
                            active_fraction=af,
                            warmup_steps=warmup,
                            explore_fraction=explore,
                            seed_offset=seed_offset,
                        )
                    )
                    seed_offset += 1

    print_summary(rows)

    print("\nNotes")
    print(" oracle_current uses current dense gradients to choose rows; it is the true upper bound.")
    print(" stale_current uses previous-step dense gradient mass; it is a renamed stale/noisy control.")
    print(" predicted_magnitude uses only EMA mass from active/observed rows.")
    print(" ucb_magnitude adds an uncertainty bonus for under-observed rows to improve discovery.")
    print(" coverage is the fraction of Linear rows that have ever been observed/active.")
    print(" new is the average fraction of rows newly observed per non-warmup step.")
    print(" dense gradients are still computed for audit; this is not a wall-clock benchmark yet.")
778
+
779
# Script entry point.
if __name__ == "__main__":
    main()
experiments/sparse_transformer_v8.py ADDED
@@ -0,0 +1,943 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v8: from masked-optimizer simulation to real sparse Linear backward.
3
+
4
+ v7 showed that Transformer Linear-row gradient support is heavy-tailed and stable,
5
+ and that a practical EMA selector can nearly match an oracle selector after a tiny
6
+ warmup. But v7 still computed dense gradients and only masked the optimizer step.
7
+
8
+ v8 tests the next question:
9
+
10
+ Can the sparse row mask be moved into the Linear backward pass itself?
11
+
12
+ Backward modes
13
+ --------------
14
+ 1. masked_optimizer
15
+ v7-style control. Compute dense backward, but MaskedAdam only updates active
16
+ Linear rows. This should match the previous simulation behavior.
17
+
18
+ 2. sparse_dW_full_dX
19
+ Custom autograd Linear computes grad_weight / grad_bias only for active output
20
+ rows, while still propagating full grad_input backward. This is the conservative
21
+ real-backward mode. It targets the dW part of Linear backward only.
22
+
23
+ 3. sparse_dW_sparse_dX
24
+ Custom autograd Linear computes grad_weight only for active rows and also
25
+ propagates grad_input only through active output rows. This is the aggressive
26
+ mode. It may save more backward compute in a real kernel, but it can damage
27
+ upstream learning.
28
+
29
+ Important caveat
30
+ ----------------
31
+ This script still performs a dense audit backward pass each training step to:
32
+ - compute oracle metrics,
33
+ - support oracle_current and stale_current controls,
34
+ - update practical EMA statistics only for active/observed rows.
35
+
36
+ The actual training update in sparse_dW_* modes comes from the custom sparse
37
+ backward pass, not from the dense audit gradients. This is a correctness and
38
+ semantics experiment, not a wall-clock benchmark.
39
+
40
+ Example
41
+ -------
42
+ Smoke test:
43
+ python3 sparse_transformer_v8.py --quick
44
+
45
+ Main comparison:
46
+ python3 sparse_transformer_v8.py \
47
+ --steps 2000 \
48
+ --active_fractions 0.05 0.02 \
49
+ --warmup_steps_list 5 \
50
+ --explore_fractions 0.00 \
51
+ --policies oracle_current predicted_magnitude random \
52
+ --backward_modes masked_optimizer sparse_dW_full_dX sparse_dW_sparse_dX
53
+ """
54
+
55
+ from __future__ import annotations
56
+
57
+ import argparse
58
+ import math
59
+ import random
60
+ from typing import Dict, List, Literal, Optional, Tuple
61
+
62
+ import torch
63
+
64
+ torch.set_num_threads(1)
65
+ import torch.nn as nn
66
+ import torch.nn.functional as F
67
+
68
+ Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"]
69
+ BackwardMode = Literal["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
70
+
71
+
72
+ # -----------------------------
73
+ # Reproducibility and device
74
+ # -----------------------------
75
+
76
def set_seed(seed: int) -> None:
    """Make runs reproducible by seeding every RNG the script relies on."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
81
+
82
+
83
def device() -> str:
    """Pick the compute device: prefer CUDA when available, else CPU."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
85
+
86
+
87
def make_cpu_generator(seed: int) -> torch.Generator:
    """Create a CPU-resident torch.Generator seeded with *seed*."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator
91
+
92
+
93
+ # -----------------------------
94
+ # Data
95
+ # -----------------------------
96
+
97
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Generate a deterministic toy corpus of templated sentences.

    The same (n_sentences, seed) pair always yields the same text, keeping the
    character vocabulary and the train/val split reproducible across runs.
    """
    rand = random.Random(seed)
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]

    lines: List[str] = []
    for _ in range(n_sentences):
        # Six sentence templates; the draw order below must stay fixed so the
        # RNG stream (and hence the corpus) is stable for a given seed.
        template = rand.randrange(6)
        if template == 0:
            sentence = f"{rand.choice(names)} {rand.choice(verbs)} {rand.choice(objects)} {rand.choice(adverbs)}."
        elif template == 1:
            sentence = f"{rand.choice(clauses)}, {rand.choice(names)} {rand.choice(verbs)} {rand.choice(objects)}."
        elif template == 2:
            a, b = rand.sample(symbols, 2)
            sentence = f"rule {a}: {rand.choice(objects)} -> {rand.choice(objects)}; rule {b}: {rand.choice(objects)} -> {rand.choice(objects)}."
        elif template == 3:
            sentence = f"the {rand.choice(objects)} {rand.choice(verbs)} the {rand.choice(objects)} {rand.choice(adverbs)}."
        elif template == 4:
            seq = " ".join(rand.choice(symbols) for _ in range(rand.randint(2, 7)))
            sentence = f"sequence {seq} ends when {rand.choice(names)} {rand.choice(verbs)}."
        else:
            sentence = f"if {rand.choice(objects)} rise then {rand.choice(names)} {rand.choice(verbs)} {rand.choice(objects)} else wait."
        lines.append(sentence)
    return "\n".join(lines) + "\n"
132
+
133
+
134
class CharCorpus:
    """Character-level corpus with a 90/10 train/val split and random-window batching."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        boundary = int(0.9 * len(encoded))
        self.train_data = encoded[:boundary]
        self.val_data = encoded[boundary:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample batch_size random (input, next-char target) windows from *split*.

        Raises ValueError when the split is too short for block_size.
        """
        source = self.train_data if split == "train" else self.val_data
        max_start = len(source) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        starts = torch.randint(max_start, (batch_size,), generator=generator)
        inputs = torch.stack([source[s : s + self.block_size] for s in starts])
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
162
+
163
+
164
def load_text(args: argparse.Namespace) -> str:
    """Read training text from --text_path, or fall back to the synthetic corpus."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as f:
        return f.read()
169
+
170
+
171
+ # -----------------------------
172
+ # Sparse Linear autograd
173
+ # -----------------------------
174
+
175
class MaskedLinearFunction(torch.autograd.Function):
    """Linear forward with a row-masked backward.

    grad_weight / grad_bias are computed only for output rows flagged in
    active_rows; grad_input is either full (sparse_dx=False) or restricted to
    the active output rows (sparse_dx=True).
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_rows: torch.Tensor,
        sparse_dx: bool,
    ) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_rows)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):  # type: ignore[override]
        x, weight, active_rows = ctx.saved_tensors
        use_sparse_dx = bool(ctx.sparse_dx)
        with_bias = bool(ctx.has_bias)

        original_shape = x.shape
        flat_x = x.reshape(-1, x.shape[-1])
        flat_gy = grad_y.reshape(-1, grad_y.shape[-1])
        rows = torch.nonzero(active_rows, as_tuple=False).flatten()

        grad_weight = torch.zeros_like(weight)
        grad_bias = None
        if with_bias:
            grad_bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)

        if rows.numel() == 0:
            # This can happen when a global top-k mask selects no rows from a
            # particular layer. Conservative full_dX still propagates through
            # that layer; aggressive sparse_dX cuts it off for that layer.
            flat_gx = torch.zeros_like(flat_x) if use_sparse_dx else flat_gy @ weight
        else:
            selected_gy = flat_gy[:, rows]
            grad_weight[rows] = selected_gy.transpose(0, 1) @ flat_x
            if grad_bias is not None:
                grad_bias[rows] = selected_gy.sum(dim=0)
            flat_gx = selected_gy @ weight[rows] if use_sparse_dx else flat_gy @ weight

        return flat_gx.reshape(original_shape), grad_weight, grad_bias, None, None
226
+
227
+
228
class SparseLinear(nn.Linear):
    """nn.Linear that can route its backward through MaskedLinearFunction.

    With sparsity disabled (the default) it behaves exactly like nn.Linear.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_rows: Optional[torch.Tensor] = None

    def set_sparse_backward(self, enabled: bool, active_rows: Optional[torch.Tensor], sparse_dx: bool) -> None:
        """Install (or clear) the row mask used by the sparse backward pass."""
        self.sparse_enabled = bool(enabled)
        self.sparse_dx = bool(sparse_dx)
        self.active_rows = active_rows

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.sparse_enabled and self.active_rows is not None:
            return MaskedLinearFunction.apply(x, self.weight, self.bias, self.active_rows, self.sparse_dx)
        return F.linear(x, self.weight, self.bias)
246
+
247
+
248
+ # -----------------------------
249
+ # Mini GPT
250
+ # -----------------------------
251
+
252
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention whose projections are SparseLinear layers."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused QKV projection; its output rows are row-mask targets too.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask, shaped for broadcast over (batch, head).
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention with the causal mask applied pre-softmax.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        # Merge heads back into the embedding dimension.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
277
+
278
+
279
class FeedForward(nn.Module):
    """Position-wise MLP: expand 4x, GELU, project back, then dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
288
+
289
+
290
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
302
+
303
+
304
class MiniGPT(nn.Module):
    """Small GPT-style character LM whose Linear layers are all SparseLinear."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = SparseLinear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are given."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        # Token + learned positional embeddings.
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
327
+
328
+
329
def named_sparse_linear_modules(model: nn.Module) -> List[Tuple[str, SparseLinear]]:
    """Collect every SparseLinear submodule of *model* with its qualified name."""
    found: List[Tuple[str, SparseLinear]] = []
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            found.append((name, module))
    return found
331
+
332
+
333
def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    """Return (total params, params inside SparseLinear layers, their ratio)."""
    total = sum(p.numel() for p in model.parameters())
    linear = sum(
        m.weight.numel() + (m.bias.numel() if m.bias is not None else 0)
        for _, m in named_sparse_linear_modules(model)
    )
    return total, linear, linear / max(1, total)
341
+
342
+
343
def configure_sparse_linears(
    model: nn.Module,
    masker: Optional["RowMasker"],
    enabled: bool,
    backward_mode: Optional[str],
) -> None:
    """Push the current row masks and backward mode into every SparseLinear."""
    use_sparse_dx = backward_mode == "sparse_dW_sparse_dX"
    for _, layer in named_sparse_linear_modules(model):
        rows = None if masker is None else masker.row_mask_for(layer)
        layer.set_sparse_backward(enabled=enabled, active_rows=rows, sparse_dx=use_sparse_dx)
353
+
354
+
355
+ # -----------------------------
356
+ # Mask selector
357
+ # -----------------------------
358
+
359
+ class RowMasker:
360
+ def __init__(
361
+ self,
362
+ model: nn.Module,
363
+ policy: Policy,
364
+ active_fraction: float,
365
+ explore_fraction: float,
366
+ mass_beta: float,
367
+ unobserved_decay: float,
368
+ warmup_steps: int,
369
+ ucb_alpha: float,
370
+ mass_init: float,
371
+ device: str,
372
+ ):
373
+ self.model = model
374
+ self.policy = policy
375
+ self.active_fraction = active_fraction
376
+ self.explore_fraction = explore_fraction
377
+ self.mass_beta = mass_beta
378
+ self.unobserved_decay = unobserved_decay
379
+ self.warmup_steps = warmup_steps
380
+ self.ucb_alpha = ucb_alpha
381
+ self.mass_init = mass_init
382
+ self.device = device
383
+ self.step_index = 0
384
+
385
+ self.linear_modules = [m for _, m in named_sparse_linear_modules(model)]
386
+ self.module_to_ids: Dict[SparseLinear, torch.Tensor] = {}
387
+ ids = []
388
+ offset = 0
389
+ for m in self.linear_modules:
390
+ n = m.weight.shape[0]
391
+ block_ids = torch.arange(offset, offset + n, device=device)
392
+ self.module_to_ids[m] = block_ids
393
+ ids.append(block_ids)
394
+ offset += n
395
+ self.n_blocks = offset
396
+
397
+ self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
398
+ self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
399
+ self.observed_count = torch.zeros(self.n_blocks, device=device)
400
+ self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)
401
+
402
+ self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
403
+ self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
404
+ self.row_masks: Dict[SparseLinear, torch.Tensor] = {
405
+ m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
406
+ }
407
+
408
+ def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
409
+ k = max(1, int(fraction * values.numel()))
410
+ mask = torch.zeros_like(values, dtype=torch.bool)
411
+ noisy = values + 1e-9 * torch.rand_like(values)
412
+ mask[torch.topk(noisy, k=k).indices] = True
413
+ return mask
414
+
415
+ @staticmethod
416
+ def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
417
+ inter = (a & b).sum().float()
418
+ union = (a | b).sum().float()
419
+ return float((inter / torch.clamp(union, min=1.0)).item())
420
+
421
+ def _set_active(self, active: torch.Tensor) -> None:
422
+ self.active = active
423
+ self.row_masks = {}
424
+ for m, ids in self.module_to_ids.items():
425
+ self.row_masks[m] = active[ids]
426
+
427
+ def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
428
+ n = self.n_blocks
429
+ k_total = max(1, int(self.active_fraction * n))
430
+ k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
431
+ k_exploit = k_total - k_explore
432
+ active = torch.zeros(n, dtype=torch.bool, device=self.device)
433
+
434
+ if k_exploit > 0:
435
+ active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
436
+ if k_explore > 0:
437
+ remaining = torch.nonzero(~active, as_tuple=False).flatten()
438
+ pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
439
+ active[pick] = True
440
+ return active
441
+
442
+ def choose_pre_backward(self, step: int) -> None:
443
+ self.step_index = step
444
+ if step < self.warmup_steps:
445
+ self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
446
+ return
447
+
448
+ if self.policy == "oracle_current":
449
+ # Oracle cannot choose until the dense audit gradient is known.
450
+ self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
451
+ return
452
+
453
+ if self.policy == "random":
454
+ self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
455
+ return
456
+
457
+ if self.policy == "stale_current":
458
+ self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
459
+ return
460
+
461
+ if self.policy == "predicted_magnitude":
462
+ self._set_active(self._sample_exploit_explore(self.predicted_mass))
463
+ return
464
+
465
+ if self.policy == "ucb_magnitude":
466
+ t = max(1, step - self.warmup_steps + 1)
467
+ log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
468
+ bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
469
+ bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
470
+ self._set_active(self._sample_exploit_explore(self.predicted_mass + bonus))
471
+ return
472
+
473
+ raise ValueError(f"Unknown policy: {self.policy}")
474
+
475
+ @torch.no_grad()
476
+ def current_gradient_mass_from_grads(self) -> torch.Tensor:
477
+ mass = torch.zeros(self.n_blocks, device=self.device)
478
+ for m, ids in self.module_to_ids.items():
479
+ if m.weight.grad is None:
480
+ continue
481
+ row_sq = m.weight.grad.square().sum(dim=1)
482
+ if m.bias is not None and m.bias.grad is not None:
483
+ row_sq = row_sq + m.bias.grad.square()
484
+ mass[ids] = torch.sqrt(row_sq + 1e-30)
485
+ return mass
486
+
487
    @torch.no_grad()
    def audit_and_update_from_mass(self, step: int, mass: torch.Tensor) -> Dict[str, float]:
        """Consume a dense-audit gradient mass: refresh the active set (warmup /
        oracle only), compute selection-quality metrics, and update the EMA
        predictor from the rows that were actually observed.

        Returns a dict of scalar metrics (cosine proxy, top-20% mass share,
        Jaccard vs. oracle, mask stability, coverage, ...).
        """
        if step < self.warmup_steps:
            # During warmup every row is active (dense training).
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            # Upper-bound control: select directly from the true current mass.
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        # Norm-ratio proxy for cosine between the masked and full gradients.
        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        # Fraction of total mass concentrated in the top 20% of rows.
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        # Rows that are active now but have never been observed before.
        new_active = active & (self.observed_count == 0)

        # Practical rule: update predicted statistics only for active/observed rows.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            # First observation seeds the EMA directly; later ones blend in.
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()

        # Dense audit signal; only stale_current is allowed to use this for selection.
        self.last_full_mass = mass.detach().clone()

        return {
            "cosine": cosine,
            "norm_ratio": cosine,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
            "coverage": float((self.observed_count > 0).float().mean().item()),
            "avg_obs_count": float(self.observed_count.mean().item()),
            "new_active_fraction": float(new_active.float().mean().item()),
        }
538
+
539
    def row_mask_for(self, module: SparseLinear) -> Optional[torch.Tensor]:
        """Return the boolean per-row active mask for *module*, or None if untracked."""
        return self.row_masks.get(module)
541
+
542
+
543
+ # -----------------------------
544
+ # Masked Adam
545
+ # -----------------------------
546
+
547
class MaskedAdam:
    """Adam-style optimizer whose updates to SparseLinear parameters can be
    restricted to the masker's active rows.

    NOTE(review): unlike standard Adam, no bias correction is applied to the
    moment estimates; this looks deliberate for the experiment, but confirm
    before reusing this optimizer elsewhere.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        # Lazily created per-parameter Adam moments ("m" and "v").
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Maps each SparseLinear parameter to (owning module, "weight"/"bias")
        # so step() can fetch the right row mask for it.
        self.linear_param: Dict[nn.Parameter, Tuple[SparseLinear, str]] = {}
        for _, m in named_sparse_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        # Set .grad to None (rather than zeroing) so gradient tensors are freed.
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one (optionally row-masked) Adam update to all parameters with grads."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            # Optionally freeze every non-SparseLinear parameter (embeddings,
            # layer norms, ...) while training sparsely.
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                continue

            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                # Classic (non-decoupled) L2 regularization folded into the gradient.
                g = g.add(p, alpha=self.weight_decay)

            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    # Broadcast the per-row mask across weight columns; the
                    # bias uses the row mask directly.
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                # Dense in-place Adam update (no bias correction).
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                mask = row_mask.expand_as(p)
                if not bool(mask.any().item()):
                    continue
                # Moments and parameter change only where the mask is True;
                # inactive rows keep their previous optimizer state untouched.
                new_m = self.beta1 * m + (1.0 - self.beta1) * g
                new_v = self.beta2 * v + (1.0 - self.beta2) * g * g
                m[mask] = new_m[mask]
                v[mask] = new_v[mask]
                update = m / (torch.sqrt(v) + self.eps)
                p[mask] = p[mask] - self.lr * update[mask]
613
+
614
+
615
+ # -----------------------------
616
+ # Training utilities
617
+ # -----------------------------
618
+
619
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int, seed: int) -> Dict[str, float]:
    """Average train/val loss over eval_iters fixed-seed batches, in dense mode.

    Sparse backward is disabled and the model is restored to train() mode on exit.
    """
    model.eval()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    results: Dict[str, float] = {}
    for split_index, split in enumerate(["train", "val"]):
        # Distinct but deterministic seed per split (train: +0, val: +100000).
        gen = make_cpu_generator(seed + 100000 * split_index)
        total = 0.0
        for _ in range(eval_iters):
            xb, yb = corpus.get_batch(split, batch_size, generator=gen)
            _, batch_loss = model(xb, yb)
            total += float(batch_loss.item())
        results[split] = total / eval_iters
    model.train()
    return results
634
+
635
+
636
def dense_audit_pass(model: nn.Module, corpus_batch: Tuple[torch.Tensor, torch.Tensor], opt: MaskedAdam, masker: RowMasker) -> torch.Tensor:
    """Run a fully dense backward on *corpus_batch* and return the per-row
    gradient mass. Gradients are cleared before returning, so the pass leaves
    no optimizer-visible state behind."""
    inputs, targets = corpus_batch
    # Make sure the custom sparse backward is off for the audit.
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    opt.zero_grad()
    _, audit_loss = model(inputs, targets)
    audit_loss.backward()
    mass = masker.current_gradient_mass_from_grads()
    opt.zero_grad()
    return mass
645
+
646
+
647
def sparse_training_backward(
    model: nn.Module,
    corpus_batch: Tuple[torch.Tensor, torch.Tensor],
    opt: MaskedAdam,
    masker: Optional[RowMasker],
    backward_mode: Optional[BackwardMode],
) -> float:
    """Run one forward/backward pass and return the loss value.

    The custom row-sparse Linear backward is enabled only for the real sparse
    modes; masked_optimizer (and the dense baseline) use the ordinary dense
    backward. Sparse mode is always switched off again before returning.
    """
    x, y = corpus_batch
    opt.zero_grad()

    use_custom_backward = (
        masker is not None
        and backward_mode is not None
        and backward_mode != "masked_optimizer"
    )
    if use_custom_backward:
        configure_sparse_linears(model, masker=masker, enabled=True, backward_mode=backward_mode)
    else:
        configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)

    _, loss = model(x, y)
    loss.backward()
    # Restore plain dense behavior so later passes start from a clean state.
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    return float(loss.item())
666
+
667
+
668
def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    backward_mode: Optional[BackwardMode],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    """Train one configuration end-to-end and return a summary-row dict.

    policy=None runs the dense baseline. Otherwise a RowMasker selects active
    rows, a dense audit pass is run every step (for metrics and, for oracle/
    stale policies, for selection), and training uses the requested
    backward_mode. Selector metrics are averaged over post-warmup steps only.
    """
    # Same model initialization and same minibatch sequence for every run by default.
    set_seed(args.seed + (seed_offset if args.unpaired_seeds else 0))
    data_gen = make_cpu_generator(args.seed + 12345)

    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )

    # Running sums of per-step selector metrics; averaged after the loop.
    sums = {
        "cosine": 0.0,
        "norm_ratio": 0.0,
        "top20_mass": 0.0,
        "jacc_oracle": 0.0,
        "stability": 0.0,
        "active_fraction_real": 0.0,
        "coverage": 0.0,
        "avg_obs_count": 0.0,
        "new_active_fraction": 0.0,
    }
    count = 0

    for step in range(args.steps):
        batch = corpus.get_batch("train", args.batch_size, generator=data_gen)

        if masker is None:
            # Dense baseline: ordinary backward + Adam step.
            loss_value = sparse_training_backward(model, batch, opt, masker=None, backward_mode=None)
            opt.step()
        else:
            # 1) pick the active set, 2) dense audit for metrics/selection,
            # 3) sparse training backward on the same batch, 4) update.
            masker.choose_pre_backward(step)
            full_mass = dense_audit_pass(model, batch, opt, masker)
            metrics = masker.audit_and_update_from_mass(step, full_mass)
            if step >= warmup_steps:
                for k in sums:
                    sums[k] += metrics[k]
                count += 1

            loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode=backward_mode)
            opt.step()

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 555)
            name = "dense" if policy is None else f"{policy}/{backward_mode}"
            print(
                f"{name:38s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"loss={loss_value:.4f} train={losses['train']:.4f} val={losses['val']:.4f}"
            )

    # Final evaluation with a separate seed from the verbose progress evals.
    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 999)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "mode": "dense" if backward_mode is None else backward_mode,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        # No selector metrics available (dense run, or no post-warmup steps).
        row.update({
            "cosine": float("nan"),
            "norm_ratio": float("nan"),
            "top20_mass": float("nan"),
            "jacc_oracle": float("nan"),
            "stability": float("nan"),
            "active_fraction_real": 1.0,
            "coverage": float("nan"),
            "avg_obs_count": float("nan"),
            "new_active_fraction": float("nan"),
        })
    else:
        for k, v in sums.items():
            row[k] = v / count
    return row
772
+
773
+
774
def print_summary(rows: List[Dict[str, float | str]]) -> None:
    """Print an aligned summary table, one line per result row."""
    print("\nSummary")
    header = (
        f"{'run':>22s} {'mode':>19s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
        f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
        f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
    )
    print(header)
    print("-" * len(header))
    for entry in rows:
        # Format every column separately, then join with single spaces so the
        # output matches the header layout exactly.
        cells = [
            f"{str(entry['run']):>22s}",
            f"{str(entry['mode']):>19s}",
            f"{float(entry['target_active']):7.3f}",
            f"{float(entry['active_fraction_real']):7.3f}",
            f"{int(float(entry['warmup'])):5d}",
            f"{float(entry['explore']):5.2f}",
            f"{float(entry['val_loss']):8.4f}",
            f"{float(entry['train_loss']):8.4f}",
            f"{float(entry['cosine']):7.3f}",
            f"{float(entry['top20_mass']):7.3f}",
            f"{float(entry['jacc_oracle']):7.3f}",
            f"{float(entry['stability']):7.3f}",
            f"{float(entry['coverage']):7.3f}",
            f"{float(entry['new_active_fraction']):7.3f}",
        ]
        print(" ".join(cells))
800
+
801
+
802
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the experiment sweep.

    Args:
        argv: Optional explicit argument list. Defaults to None, in which case
            argparse reads sys.argv[1:] — identical to the previous behavior.
            Passing a list makes the parser testable and embeddable without
            touching process-global state.

    Returns:
        The populated argparse.Namespace.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--text_path", type=str, default=None)
    p.add_argument("--synthetic_sentences", type=int, default=12000)
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--quick", action="store_true")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--block_size", type=int, default=64)
    p.add_argument("--n_layer", type=int, default=2)
    p.add_argument("--n_head", type=int, default=4)
    p.add_argument("--n_embd", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--active_fractions", type=float, nargs="+", default=[0.05, 0.02])
    p.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "random"])
    p.add_argument(
        "--backward_modes",
        type=str,
        nargs="+",
        default=["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"],
    )
    p.add_argument("--explore_fractions", type=float, nargs="+", default=[0.0])
    p.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5])
    p.add_argument("--mass_beta", type=float, default=0.95)
    p.add_argument("--unobserved_decay", type=float, default=1.0)
    p.add_argument("--mass_init", type=float, default=0.0)
    p.add_argument("--ucb_alpha", type=float, default=1.0)
    p.add_argument("--freeze_non_linear_when_sparse", action="store_true")
    p.add_argument("--eval_interval", type=int, default=200)
    p.add_argument("--eval_iters", type=int, default=20)
    p.add_argument("--seed", type=int, default=7)
    p.add_argument("--unpaired_seeds", action="store_true", help="Use different init seeds per run instead of paired seeds.")
    p.add_argument("--verbose", action="store_true")
    return p.parse_args(argv)
837
+
838
+
839
def main() -> None:
    """Parse args, run the dense baseline plus every sparse configuration,
    then print the summary table and interpretation notes."""
    args = parse_args()
    if args.quick:
        # Tiny smoke-test configuration overriding most knobs.
        args.steps = 40
        args.eval_iters = 2
        args.batch_size = 8
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 1200
        args.active_fractions = [0.05]
        args.policies = ["predicted_magnitude", "random"]
        args.backward_modes = ["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
        args.explore_fractions = [0.0]
        args.warmup_steps_list = [5]

    # Validate user-supplied names up front, before any training starts.
    valid_policies = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
    valid_modes = {"masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"}
    for pol in args.policies:
        if pol not in valid_policies:
            raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid_policies)}")
    for mode in args.backward_modes:
        if mode not in valid_modes:
            raise ValueError(f"Unknown backward mode {mode!r}. Valid modes: {sorted(valid_modes)}")

    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"backward_modes={args.backward_modes}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")
    print(f"paired_seeds={not args.unpaired_seeds}")

    # Throwaway model used only to report parameter counts.
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")

    if args.dropout != 0.0:
        print("warning: dropout is nonzero; dense audit and sparse training passes may see different dropout masks")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(
        train_run(
            corpus,
            args,
            policy=None,
            backward_mode=None,
            active_fraction=1.0,
            warmup_steps=0,
            explore_fraction=0.0,
            seed_offset=0,
        )
    )

    # Sweep every (mode, active fraction, policy, warmup, explore) combination.
    seed_offset = 100
    for mode in args.backward_modes:
        for af in args.active_fractions:
            for pol in args.policies:
                # Exploration only applies to the predictive policies.
                explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
                for warmup in args.warmup_steps_list:
                    for explore in explore_values:
                        print(
                            f"\nRunning mode={mode}, policy={pol}, "
                            f"active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}"
                        )
                        rows.append(
                            train_run(
                                corpus,
                                args,
                                policy=pol,  # type: ignore[arg-type]
                                backward_mode=mode,  # type: ignore[arg-type]
                                active_fraction=af,
                                warmup_steps=warmup,
                                explore_fraction=explore,
                                seed_offset=seed_offset,
                            )
                        )
                        seed_offset += 1

    print_summary(rows)

    print("\nNotes")
    print(" masked_optimizer is the v7-style dense-backward simulation control.")
    print(" sparse_dW_full_dX uses custom Linear backward: sparse weight/bias grads, full input gradient.")
    print(" sparse_dW_sparse_dX uses custom Linear backward: sparse weight/bias grads and sparse input gradient.")
    print(" oracle_current uses dense audit gradients to choose rows; it is an upper bound.")
    print(" predicted_magnitude uses EMA mass from active/observed rows only.")
    print(" random is the sparse-support control.")
    print(" dense audit gradients are still computed every step for metrics/control; this is not a speed benchmark.")
    print(" The key comparison is masked_optimizer vs sparse_dW_full_dX. If they match, the v7 effect survives real dW sparsification.")
940
+
941
+
942
if __name__ == "__main__":
    # Script entry point.
    main()
experiments/sparse_transformer_v9.py ADDED
@@ -0,0 +1,1042 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v9: no-audit sparse training after dense warmup.
3
+
4
+ v8 proved that the row-sparse mask can be moved into a custom Linear backward.
5
+ v9 removes the remaining dense-audit crutch.
6
+
7
+ Default behavior
8
+ ----------------
9
+ 1. Run a short dense warmup, usually 5 steps.
10
+ 2. Initialize the EMA row-importance predictor from those dense warmup gradients.
11
+ 3. After warmup, choose active rows from the predictor.
12
+ 4. Train using sparse backward.
13
+ 5. Update EMA statistics only from rows that were actually active/observed.
14
+ 6. Do not compute dense gradients unless --audit_every > 0.
15
+
16
+ Audit behavior
17
+ --------------
18
+ --audit_every 0
19
+ No dense audit after warmup. Cosine/Jaccard/top20 are unavailable and show as nan.
20
+
21
+ --audit_every N
22
+ Every N steps, run an extra dense backward pass on the same batch only to
23
+ measure cosine/top20/Jaccard. The audit is NOT used to update the selector,
24
+ except for oracle_current, which is explicitly an upper-bound control.
25
+
26
+ This is still not a wall-clock benchmark on vanilla PyTorch/MPS/CPU. The custom
27
+ backward uses indexing and ordinary PyTorch matmuls. The goal is to verify that
28
+ the method survives without dense information after warmup.
29
+
30
+ Examples
31
+ --------
32
+ No-audit practical run:
33
+ python3 sparse_transformer_v9.py \
34
+ --device mps \
35
+ --steps 2000 \
36
+ --active_fractions 0.05 0.02 \
37
+ --warmup_steps_list 5 \
38
+ --policies predicted_magnitude random \
39
+ --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
40
+ --audit_every 0
41
+
42
+ Occasional audit for measurement only:
43
+ python3 sparse_transformer_v9.py \
44
+ --steps 2000 \
45
+ --active_fractions 0.05 0.02 \
46
+ --warmup_steps_list 5 \
47
+ --policies predicted_magnitude random \
48
+ --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
49
+ --audit_every 100
50
+ """
51
+
52
+
53
+ from __future__ import annotations
54
+
55
+ import argparse
56
+ import math
57
+ import random
58
+ from typing import Dict, List, Literal, Optional, Tuple
59
+
60
+ import torch
61
+
62
+ torch.set_num_threads(1)
63
+ import torch.nn as nn
64
+ import torch.nn.functional as F
65
+
66
+ Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"]
67
+ BackwardMode = Literal["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
68
+
69
+
70
+ # -----------------------------
71
+ # Reproducibility and device
72
+ # -----------------------------
73
+
74
def set_seed(seed: int) -> None:
    """Seed the Python, Torch CPU and (when available) all CUDA RNGs."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
79
+
80
+
81
def default_device() -> str:
    """Pick the best available backend: cuda, then mps, then cpu."""
    if torch.cuda.is_available():
        return "cuda"
    mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return "mps" if mps_ok else "cpu"
87
+
88
+
89
def make_cpu_generator(seed: int) -> torch.Generator:
    """Create a deterministically seeded CPU-side random generator."""
    # Generator.manual_seed returns the generator itself, so this chains.
    return torch.Generator(device="cpu").manual_seed(seed)
93
+
94
+
95
+ # -----------------------------
96
+ # Data
97
+ # -----------------------------
98
+
99
+ def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
100
+ rng = random.Random(seed)
101
+ names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
102
+ verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
103
+ objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
104
+ adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
105
+ clauses = [
106
+ "when the loss falls",
107
+ "after the mask shifts",
108
+ "before the model answers",
109
+ "while the signal drifts",
110
+ "if the pattern repeats",
111
+ "because the tail is noisy",
112
+ ]
113
+ symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]
114
+
115
+ lines: List[str] = []
116
+ for _ in range(n_sentences):
117
+ t = rng.randrange(6)
118
+ if t == 0:
119
+ line = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
120
+ elif t == 1:
121
+ line = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
122
+ elif t == 2:
123
+ a, b = rng.sample(symbols, 2)
124
+ line = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}."
125
+ elif t == 3:
126
+ line = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
127
+ elif t == 4:
128
+ seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
129
+ line = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}."
130
+ else:
131
+ line = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."
132
+ lines.append(line)
133
+ return "\n".join(lines) + "\n"
134
+
135
+
136
class CharCorpus:
    """Character-level corpus with a 90/10 train/val split and random windows."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for i, ch in enumerate(vocab)}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        cut = int(0.9 * len(encoded))
        self.train_data = encoded[:cut]
        self.val_data = encoded[cut:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a (inputs, shifted-targets) batch of random windows from *split*."""
        source = self.train_data if split == "train" else self.val_data
        max_start = len(source) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        starts = torch.randint(max_start, (batch_size,), generator=generator)
        inputs = torch.stack([source[s : s + self.block_size] for s in starts])
        # Targets are the same windows shifted one character ahead.
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
164
+
165
+
166
def load_text(args: argparse.Namespace) -> str:
    """Return the training text: --text_path contents if given, else a synthetic corpus."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as handle:
        return handle.read()
171
+
172
+
173
+ # -----------------------------
174
+ # Sparse Linear autograd
175
+ # -----------------------------
176
+
177
class MaskedLinearFunction(torch.autograd.Function):
    """Dense-forward linear whose backward only produces weight/bias gradients
    for the output rows flagged in ``active_rows``; the input gradient can
    optionally be restricted to those rows too (``sparse_dx``)."""

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_rows: torch.Tensor,
        sparse_dx: bool,
    ) -> torch.Tensor:
        # Forward is an ordinary dense linear; sparsity applies only in backward.
        ctx.save_for_backward(x, weight, active_rows)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):  # type: ignore[override]
        x, weight, active_rows = ctx.saved_tensors
        sparse_dx = bool(ctx.sparse_dx)
        has_bias = bool(ctx.has_bias)

        # Flatten leading (batch/time) dims so the matmuls below are 2-D.
        x_shape = x.shape
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])

        active_idx = torch.nonzero(active_rows, as_tuple=False).flatten()

        # Inactive rows keep zero gradients.
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if has_bias else None

        if active_idx.numel() > 0:
            # Only the active output rows receive weight/bias gradients.
            gy_active = gy_flat[:, active_idx]
            grad_weight[active_idx] = gy_active.transpose(0, 1) @ x_flat
            if grad_bias is not None:
                grad_bias[active_idx] = gy_active.sum(dim=0)

            if sparse_dx:
                # Input gradient from the active rows' contribution only.
                grad_x_flat = gy_active @ weight[active_idx]
            else:
                grad_x_flat = gy_flat @ weight
        else:
            # This can happen when a global top-k mask selects no rows from a
            # particular layer. Conservative full_dX still propagates through that
            # layer; aggressive sparse_dX cuts it off for that layer.
            if sparse_dx:
                grad_x_flat = torch.zeros_like(x_flat)
            else:
                grad_x_flat = gy_flat @ weight

        grad_x = grad_x_flat.reshape(x_shape)
        # Non-tensor forward inputs (active_rows, sparse_dx) get None gradients.
        return grad_x, grad_weight, grad_bias, None, None
228
+
229
+
230
class SparseLinear(nn.Linear):
    """nn.Linear with an optional row-sparse backward pass.

    With sparsity disabled (the default) it behaves exactly like nn.Linear;
    when enabled, the custom MaskedLinearFunction restricts gradients to the
    installed active-row mask.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        # Sparse backward is opt-in and off until configured.
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_rows: Optional[torch.Tensor] = None

    def set_sparse_backward(self, enabled: bool, active_rows: Optional[torch.Tensor], sparse_dx: bool) -> None:
        """Toggle the row-sparse backward and install the active-row mask."""
        self.sparse_enabled = bool(enabled)
        self.sparse_dx = bool(sparse_dx)
        self.active_rows = active_rows

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        use_sparse = self.sparse_enabled and self.active_rows is not None
        if use_sparse:
            return MaskedLinearFunction.apply(x, self.weight, self.bias, self.active_rows, self.sparse_dx)
        return F.linear(x, self.weight, self.bias)
248
+
249
+
250
+ # -----------------------------
251
+ # Mini GPT
252
+ # -----------------------------
253
+
254
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built from SparseLinear projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused Q/K/V projection plus the output projection.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask, broadcastable over (batch, head).
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention with the causal mask applied pre-softmax.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        # Merge heads back into the embedding dimension.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
279
+
280
+
281
class FeedForward(nn.Module):
    """Two-layer GELU MLP with the usual 4x hidden expansion."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
290
+
291
+
292
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))
304
+
305
+
306
class MiniGPT(nn.Module):
    """Small GPT-style character language model whose Linear layers are SparseLinear."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = SparseLinear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are given."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        # Token embedding plus learned positional embedding.
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # Flatten (B, T) for token-level cross entropy.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
329
+
330
+
331
def named_sparse_linear_modules(model: nn.Module) -> List[Tuple[str, SparseLinear]]:
    """Return (qualified_name, module) for every SparseLinear inside *model*."""
    found: List[Tuple[str, SparseLinear]] = []
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            found.append((name, module))
    return found
333
+
334
+
335
def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    """Return (total params, params inside SparseLinear layers, their fraction)."""
    total = sum(p.numel() for p in model.parameters())
    linear = 0
    for _, module in named_sparse_linear_modules(model):
        linear += module.weight.numel()
        if module.bias is not None:
            linear += module.bias.numel()
    # max(1, total) guards against division by zero for a parameterless model.
    return total, linear, linear / max(1, total)
343
+
344
+
345
def configure_sparse_linears(
    model: nn.Module,
    masker: Optional["RowMasker"],
    enabled: bool,
    backward_mode: Optional[str],
) -> None:
    """Push the current row masks and backward configuration into every SparseLinear."""
    wants_sparse_dx = backward_mode == "sparse_dW_sparse_dX"
    for _, module in named_sparse_linear_modules(model):
        if masker is None:
            rows = None
        else:
            rows = masker.row_mask_for(module)
        module.set_sparse_backward(enabled=enabled, active_rows=rows, sparse_dx=wants_sparse_dx)
355
+
356
+
357
+ # -----------------------------
358
+ # Mask selector
359
+ # -----------------------------
360
+
361
class RowMasker:
    """Selects which output rows ("blocks") of the SparseLinear layers are active
    and maintains per-block gradient-mass statistics that drive the policies.

    Fixes vs. the previous revision:
      * removed the duplicated ``@torch.no_grad()`` decorator on
        ``update_predictor_from_observed_mass`` (it was stacked twice);
      * removed the dead ``ids`` accumulator list in ``__init__``.
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        ucb_alpha: float,
        mass_init: float,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.ucb_alpha = ucb_alpha
        self.mass_init = mass_init
        self.device = device
        self.step_index = 0

        # Assign every SparseLinear output row a unique global block id.
        self.linear_modules = [m for _, m in named_sparse_linear_modules(model)]
        self.module_to_ids: Dict[SparseLinear, torch.Tensor] = {}
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            self.module_to_ids[m] = torch.arange(offset, offset + n, device=device)
            offset += n
        self.n_blocks = offset

        # Per-block statistics consumed by the selection policies.
        self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.observed_count = torch.zeros(self.n_blocks, device=device)
        self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)

        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[SparseLinear, torch.Tensor] = {
            m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
        }

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask over the top `fraction` of `values` (at least one entry).

        A tiny random jitter breaks ties so identical scores do not always
        resolve to the same indices.
        """
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        noisy = values + 1e-9 * torch.rand_like(values)
        mask[torch.topk(noisy, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard similarity of two boolean masks; 0.0 when both are empty."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        """Install `active` as the current global mask and refresh per-module views."""
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]

    def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
        """Pick top-scoring blocks, then fill an explore quota with random others."""
        n = self.n_blocks
        k_total = max(1, int(self.active_fraction * n))
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(n, dtype=torch.bool, device=self.device)

        if k_exploit > 0:
            active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
        if k_explore > 0:
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
            active[pick] = True
        return active

    def choose_pre_backward(self, step: int) -> None:
        """Select the active set for `step` before the training backward pass."""
        self.step_index = step
        if step < self.warmup_steps:
            # Dense warmup: everything is active.
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "oracle_current":
            # Oracle cannot choose until the dense audit gradient is known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "random":
            self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
            return

        if self.policy == "stale_current":
            self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
            return

        if self.policy == "predicted_magnitude":
            self._set_active(self._sample_exploit_explore(self.predicted_mass))
            return

        if self.policy == "ucb_magnitude":
            # UCB bonus shrinks with observation count; scaled by the global EMA.
            t = max(1, step - self.warmup_steps + 1)
            log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
            bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
            bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
            self._set_active(self._sample_exploit_explore(self.predicted_mass + bonus))
            return

        raise ValueError(f"Unknown policy: {self.policy}")

    @torch.no_grad()
    def current_gradient_mass_from_grads(self) -> torch.Tensor:
        """Per-block L2 gradient mass (weight row + bias element) from current .grad."""
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def update_predictor_from_observed_mass(self, mass: torch.Tensor, observed: Optional[torch.Tensor] = None) -> Dict[str, float]:
        """Update EMA statistics only for observed rows.

        After warmup, sparse backward only gives trustworthy gradients for active
        rows, so only those rows are allowed to update predicted_mass.
        """
        if observed is None:
            observed = self.active

        new_active = observed & (self.observed_count == 0)
        self.predicted_mass.mul_(self.unobserved_decay)

        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            # First observation seeds the EMA directly instead of blending with mass_init.
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()

        stability = self._jaccard(self.active, self.prev_active)
        self.prev_active = self.active.clone()

        return {
            "stability": stability,
            "active_fraction_real": float(self.active.float().mean().item()),
            "coverage": float((self.observed_count > 0).float().mean().item()),
            "avg_obs_count": float(self.observed_count.mean().item()),
            "new_active_fraction": float(new_active.float().mean().item()),
        }

    @torch.no_grad()
    def audit_metrics_from_mass(self, mass: torch.Tensor) -> Dict[str, float]:
        """Compute dense-audit metrics without updating the practical selector."""
        active = self.active
        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)

        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        return {
            "cosine": cosine,
            "norm_ratio": cosine,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
        }

    def audit_and_update_from_mass(self, step: int, mass: torch.Tensor) -> Dict[str, float]:
        """Dense-audit path: compute metrics AND update the selector statistics."""
        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        new_active = active & (self.observed_count == 0)

        # Practical rule: update predicted statistics only for active/observed rows.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()

        # Dense audit signal; only stale_current is allowed to use this for selection.
        self.last_full_mass = mass.detach().clone()

        return {
            "cosine": cosine,
            "norm_ratio": cosine,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
            "coverage": float((self.observed_count > 0).float().mean().item()),
            "avg_obs_count": float(self.observed_count.mean().item()),
            "new_active_fraction": float(new_active.float().mean().item()),
        }

    def row_mask_for(self, module: SparseLinear) -> Optional[torch.Tensor]:
        """Boolean active-row mask for `module`, or None if the module is unknown."""
        return self.row_masks.get(module)
597
+
598
+
599
+ # -----------------------------
600
+ # Masked Adam
601
+ # -----------------------------
602
+
603
class MaskedAdam:
    """Adam-style optimizer that can restrict SparseLinear weight/bias updates
    to the rows currently marked active by a RowMasker.

    NOTE(review): no bias-correction terms are applied (no global step counter
    is kept); the dense and the masked update paths omit them consistently,
    which appears deliberate for this experiment.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        # Lazily-populated per-parameter Adam state: first/second moment tensors.
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Maps each SparseLinear weight/bias parameter back to (module, kind)
        # so step() can look up the module's row mask.
        self.linear_param: Dict[nn.Parameter, Tuple[SparseLinear, str]] = {}
        for _, m in named_sparse_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        """Drop gradients by setting .grad to None on every model parameter."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one (optionally row-masked) Adam update to all parameters with grads."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            # Optionally skip non-SparseLinear params (embeddings, layernorms)
            # entirely while a sparse run is active.
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                continue

            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                # L2-style decay folded into the gradient (classic Adam, not AdamW).
                g = g.add(p, alpha=self.weight_decay)

            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    # Weight: broadcast the row mask over trailing dims; bias: use as-is.
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                # Dense path: standard (uncorrected) Adam moment updates in place.
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                # MPS can mis-handle expanded boolean masks for row-wise assignment
                # (e.g. reporting nonsense out-of-bounds indices). Use explicit
                # row indices and index_copy_ instead. This also avoids materializing
                # a full expanded mask for weight matrices.
                active_rows = row_mask.reshape(-1).nonzero(as_tuple=False).flatten()
                if active_rows.numel() == 0:
                    continue

                m_rows = m.index_select(0, active_rows)
                v_rows = v.index_select(0, active_rows)
                g_rows = g.index_select(0, active_rows)

                # Same Adam arithmetic as the dense path, restricted to active rows.
                new_m_rows = self.beta1 * m_rows + (1.0 - self.beta1) * g_rows
                new_v_rows = self.beta2 * v_rows + (1.0 - self.beta2) * g_rows * g_rows
                update_rows = new_m_rows / (torch.sqrt(new_v_rows) + self.eps)
                p_rows = p.index_select(0, active_rows) - self.lr * update_rows

                m.index_copy_(0, active_rows, new_m_rows)
                v.index_copy_(0, active_rows, new_v_rows)
                p.index_copy_(0, active_rows, p_rows)
680
+
681
+
682
+ # -----------------------------
683
+ # Training utilities
684
+ # -----------------------------
685
+
686
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int, seed: int) -> Dict[str, float]:
    """Average cross-entropy loss over `eval_iters` batches of each split.

    Sparse backward hooks are disabled and the model is returned to train mode
    before this function exits.
    """
    model.eval()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    results: Dict[str, float] = {}
    # Each split gets its own deterministic generator (val offset by 100000).
    for seed_offset, split in ((0, "train"), (100000, "val")):
        gen = make_cpu_generator(seed + seed_offset)
        total = 0.0
        for _ in range(eval_iters):
            x, y = corpus.get_batch(split, batch_size, generator=gen)
            _, loss = model(x, y)
            total += float(loss.item())
        results[split] = total / eval_iters
    model.train()
    return results
701
+
702
+
703
def dense_audit_pass(model: nn.Module, corpus_batch: Tuple[torch.Tensor, torch.Tensor], opt: MaskedAdam, masker: RowMasker) -> torch.Tensor:
    """Run one dense backward pass purely to measure per-block gradient mass.

    All gradients are zeroed before returning, so this pass never affects the
    optimizer state or parameters.
    """
    inputs, targets = corpus_batch
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    opt.zero_grad()
    _, audit_loss = model(inputs, targets)
    audit_loss.backward()
    mass = masker.current_gradient_mass_from_grads()
    opt.zero_grad()
    return mass
712
+
713
+
714
def sparse_training_backward(
    model: nn.Module,
    corpus_batch: Tuple[torch.Tensor, torch.Tensor],
    opt: MaskedAdam,
    masker: Optional[RowMasker],
    backward_mode: Optional[BackwardMode],
) -> float:
    """Forward + backward for one batch, optionally through the custom sparse
    Linear backward. Returns the scalar training loss.

    Sparse hooks are always disabled again before returning.
    """
    inputs, targets = corpus_batch
    opt.zero_grad()

    # "masked_optimizer" simulates sparsity in the optimizer, so the backward
    # pass itself stays dense in that mode too.
    use_custom_backward = (
        masker is not None and backward_mode is not None and backward_mode != "masked_optimizer"
    )
    if use_custom_backward:
        configure_sparse_linears(model, masker=masker, enabled=True, backward_mode=backward_mode)
    else:
        configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)

    _, loss = model(inputs, targets)
    loss.backward()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    return float(loss.item())
733
+
734
+
735
def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    backward_mode: Optional[BackwardMode],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    """Train one MiniGPT configuration and return a summary-row dict.

    policy=None runs the dense baseline (no masker). Otherwise a RowMasker
    drives row selection: a dense bootstrap for `warmup_steps`, then the chosen
    policy with the given backward mode. Selector/audit metrics are averaged
    over the steps at which they were produced; final train/val losses come
    from a fresh deterministic evaluation.
    """
    # Same model initialization and same minibatch sequence for every run by default.
    set_seed(args.seed + (seed_offset if args.unpaired_seeds else 0))
    data_gen = make_cpu_generator(args.seed + 12345)

    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )

    # Running sums/counts used to average metrics that are only produced on
    # some steps (e.g. audit metrics only on audit steps).
    sums = {
        "cosine": 0.0,
        "norm_ratio": 0.0,
        "top20_mass": 0.0,
        "jacc_oracle": 0.0,
        "stability": 0.0,
        "active_fraction_real": 0.0,
        "coverage": 0.0,
        "avg_obs_count": 0.0,
        "new_active_fraction": 0.0,
    }
    counts = {k: 0 for k in sums}

    def add_metrics(metrics: Dict[str, float]) -> None:
        # Accumulate only the keys we track; unknown keys are ignored.
        for k, v in metrics.items():
            if k in sums:
                sums[k] += float(v)
                counts[k] += 1

    for step in range(args.steps):
        batch = corpus.get_batch("train", args.batch_size, generator=data_gen)

        if masker is None:
            # Dense baseline: plain forward/backward + full Adam step.
            loss_value = sparse_training_backward(model, batch, opt, masker=None, backward_mode=None)
            opt.step()
        else:
            if step < warmup_steps:
                # Dense bootstrap. Every row is active and every row updates the predictor.
                masker._set_active(torch.ones(masker.n_blocks, dtype=torch.bool, device=dev))
                loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode="masked_optimizer")
                full_mass = masker.current_gradient_mass_from_grads()
                masker.last_full_mass = full_mass.detach().clone()
                add_metrics(masker.audit_metrics_from_mass(full_mass))
                add_metrics(masker.update_predictor_from_observed_mass(full_mass, observed=masker.active))
                opt.step()
            else:
                masker.choose_pre_backward(step)

                if policy == "oracle_current":
                    # Explicit upper bound. Oracle necessarily computes dense gradients to choose rows.
                    full_mass = dense_audit_pass(model, batch, opt, masker)
                    masker._set_active(masker._topk_mask(full_mass, active_fraction))
                    masker.last_full_mass = full_mass.detach().clone()
                    add_metrics(masker.audit_metrics_from_mass(full_mass))
                elif args.audit_every > 0 and ((step - warmup_steps) % args.audit_every == 0):
                    # Measurement only. Do not update predicted_magnitude/ucb/random with this dense mass.
                    full_mass = dense_audit_pass(model, batch, opt, masker)
                    add_metrics(masker.audit_metrics_from_mass(full_mass))
                    if policy == "stale_current":
                        masker.last_full_mass = full_mass.detach().clone()

                loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode=backward_mode)

                # Practical selector update: only active rows were observed by the training backward pass.
                observed_mass = masker.current_gradient_mass_from_grads()
                add_metrics(masker.update_predictor_from_observed_mass(observed_mass, observed=masker.active))
                opt.step()

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 555)
            name = "dense" if policy is None else f"{policy}/{backward_mode}"
            print(
                f"{name:38s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"loss={loss_value:.4f} train={losses['train']:.4f} val={losses['val']:.4f}"
            )

    # Final evaluation uses a fixed seed so all runs are compared identically.
    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 999)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "mode": "dense" if backward_mode is None else backward_mode,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None:
        # Dense baseline has no selector, so sparsity metrics are undefined.
        row.update({
            "cosine": float("nan"),
            "norm_ratio": float("nan"),
            "top20_mass": float("nan"),
            "jacc_oracle": float("nan"),
            "stability": float("nan"),
            "active_fraction_real": 1.0,
            "coverage": float("nan"),
            "avg_obs_count": float("nan"),
            "new_active_fraction": float("nan"),
        })
    else:
        for k in sums:
            row[k] = (sums[k] / counts[k]) if counts[k] > 0 else float("nan")
    return row
866
+
867
+ def print_summary(rows: List[Dict[str, float | str]]) -> None:
868
+ print("\nSummary")
869
+ header = (
870
+ f"{'run':>22s} {'mode':>19s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
871
+ f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
872
+ f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
873
+ )
874
+ print(header)
875
+ print("-" * len(header))
876
+ for r in rows:
877
+ print(
878
+ f"{str(r['run']):>22s} "
879
+ f"{str(r['mode']):>19s} "
880
+ f"{float(r['target_active']):7.3f} "
881
+ f"{float(r['active_fraction_real']):7.3f} "
882
+ f"{int(float(r['warmup'])):5d} "
883
+ f"{float(r['explore']):5.2f} "
884
+ f"{float(r['val_loss']):8.4f} "
885
+ f"{float(r['train_loss']):8.4f} "
886
+ f"{float(r['cosine']):7.3f} "
887
+ f"{float(r['top20_mass']):7.3f} "
888
+ f"{float(r['jacc_oracle']):7.3f} "
889
+ f"{float(r['stability']):7.3f} "
890
+ f"{float(r['coverage']):7.3f} "
891
+ f"{float(r['new_active_fraction']):7.3f}"
892
+ )
893
+
894
+
895
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI for the sparse-vs-dense training sweep.

    Flags cover data/model size, optimizer settings, the sweep axes
    (active_fractions x policies x backward_modes x warmup x explore),
    selector hyperparameters, and auditing/seeding behavior.
    """
    p = argparse.ArgumentParser()
    # Data / model size.
    p.add_argument("--text_path", type=str, default=None)
    p.add_argument("--synthetic_sentences", type=int, default=12000)
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--quick", action="store_true")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--block_size", type=int, default=64)
    p.add_argument("--n_layer", type=int, default=2)
    p.add_argument("--n_head", type=int, default=4)
    p.add_argument("--n_embd", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.0)
    # Optimizer.
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.0)
    # Sweep axes.
    p.add_argument("--active_fractions", type=float, nargs="+", default=[0.05, 0.02])
    p.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "random"])
    p.add_argument(
        "--backward_modes",
        type=str,
        nargs="+",
        default=["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"],
    )
    p.add_argument("--explore_fractions", type=float, nargs="+", default=[0.0])
    p.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5])
    # Selector hyperparameters.
    p.add_argument("--mass_beta", type=float, default=0.95)
    p.add_argument("--unobserved_decay", type=float, default=1.0)
    p.add_argument("--mass_init", type=float, default=0.0)
    p.add_argument("--ucb_alpha", type=float, default=1.0)
    p.add_argument("--freeze_non_linear_when_sparse", action="store_true")
    # Evaluation / reproducibility.
    p.add_argument("--eval_interval", type=int, default=200)
    p.add_argument("--eval_iters", type=int, default=20)
    p.add_argument("--seed", type=int, default=7)
    p.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda", "mps"])
    p.add_argument("--audit_every", type=int, default=0, help="Dense audit interval after warmup. 0 disables audits except oracle_current.")
    p.add_argument("--unpaired_seeds", action="store_true", help="Use different init seeds per run instead of paired seeds.")
    p.add_argument("--verbose", action="store_true")
    return p.parse_args()
932
+
933
+
934
def main() -> None:
    """Run the dense baseline plus the full sparse-policy sweep and print a summary."""
    args = parse_args()
    if args.quick:
        # Tiny smoke-test configuration; overrides most CLI settings.
        args.steps = 40
        args.eval_iters = 2
        args.batch_size = 8
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 1200
        args.active_fractions = [0.05]
        args.policies = ["predicted_magnitude", "random"]
        args.backward_modes = ["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
        args.explore_fractions = [0.0]
        args.warmup_steps_list = [5]
        args.audit_every = 10

    # Validate sweep axes up-front so a typo fails fast, not mid-sweep.
    valid_policies = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
    valid_modes = {"masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"}
    for pol in args.policies:
        if pol not in valid_policies:
            raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid_policies)}")
    for mode in args.backward_modes:
        if mode not in valid_modes:
            raise ValueError(f"Unknown backward mode {mode!r}. Valid modes: {sorted(valid_modes)}")

    set_seed(args.seed)
    dev = args.device if args.device != "auto" else default_device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"backward_modes={args.backward_modes}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")
    print(f"paired_seeds={not args.unpaired_seeds}")
    print(f"audit_every={args.audit_every} (0 means no dense audit after warmup, except oracle_current)")

    # Throwaway model only used to report parameter counts.
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")

    if args.dropout != 0.0:
        print("warning: dropout is nonzero; dense audit and sparse training passes may see different dropout masks")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(
        train_run(
            corpus,
            args,
            policy=None,
            backward_mode=None,
            active_fraction=1.0,
            warmup_steps=0,
            explore_fraction=0.0,
            seed_offset=0,
        )
    )

    # Full sweep: each combination gets its own (optional) seed offset.
    seed_offset = 100
    for mode in args.backward_modes:
        for af in args.active_fractions:
            for pol in args.policies:
                # Exploration only matters for the EMA/UCB policies.
                explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
                for warmup in args.warmup_steps_list:
                    for explore in explore_values:
                        print(
                            f"\nRunning mode={mode}, policy={pol}, "
                            f"active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}"
                        )
                        rows.append(
                            train_run(
                                corpus,
                                args,
                                policy=pol,  # type: ignore[arg-type]
                                backward_mode=mode,  # type: ignore[arg-type]
                                active_fraction=af,
                                warmup_steps=warmup,
                                explore_fraction=explore,
                                seed_offset=seed_offset,
                            )
                        )
                        seed_offset += 1

    print_summary(rows)

    print("\nNotes")
    print(" masked_optimizer is the v7-style dense-backward simulation control.")
    print(" sparse_dW_full_dX uses custom Linear backward: sparse weight/bias grads, full input gradient.")
    print(" sparse_dW_sparse_dX uses custom Linear backward: sparse weight/bias grads and sparse input gradient.")
    print(" oracle_current uses dense audit gradients to choose rows; it is an upper bound.")
    print(" predicted_magnitude uses EMA mass from active/observed rows only.")
    print(" random is the sparse-support control.")
    print(" v9 does not compute dense audit gradients after warmup unless --audit_every > 0, except oracle_current.")
    print(" predicted_magnitude updates EMA statistics only from active rows observed by the training backward pass.")
    print(" cosine/top20/jacc are nan when --audit_every 0 because no dense reference gradient is computed.")
    print(" This is still not a wall-clock benchmark: PyTorch indexing may not accelerate on CPU/MPS without a custom Metal kernel.")
1039
+
1040
+
1041
# Script entry point: run the full benchmark sweep when executed directly.
if __name__ == "__main__":
    main()
experiments/surprise_topk_gradient_prototype-v2.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Surprise Top-K Gradient Prototype, v2
3
+
4
+ Goal
5
+ ----
6
+ Test the hypothesis:
7
+
8
+ gradient_t ≈ predicted_gradient_t + sparse_surprising_residual_t
9
+
10
+ This version fixes two problems from v1:
11
+ 1. The baseline now actually learns the toy task, using Adam instead of raw SGD.
12
+ 2. The surprise method is tested as an approximate gradient passed into Adam,
13
+ rather than as a hand-written SGD update with unstable error-feedback buffers.
14
+
15
+ Important caveat
16
+ ----------------
17
+ This still computes the full gradient every step. That is intentional. This is a
18
+ hypothesis test: does a sparse "surprising residual" preserve the useful update
19
+ signal? If yes, a later version can try to skip real backward computation.
20
+
21
+ Run
22
+ ---
23
+ python3 surprise_topk_gradient_prototype.py
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import math
29
+ import random
30
+ from dataclasses import dataclass
31
+ from typing import Dict, List, Tuple
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+
37
+
38
# Global determinism: seed Python's and torch's RNGs once at import time.
SEED = 7
random.seed(SEED)
torch.manual_seed(SEED)

# Prefer CUDA when available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
43
+
44
+
45
+ # -----------------------------
46
+ # Toy data: 2-class spiral
47
+ # -----------------------------
48
+
49
def make_spiral(n_per_class: int = 1024, noise: float = 0.12) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate a shuffled two-arm spiral classification dataset.

    Returns (features, labels): features are (2 * n_per_class, 2) points,
    labels are the matching class ids (0 or 1); both are moved to DEVICE.
    """
    arms = []
    for class_id in range(2):
        radius = torch.linspace(0.0, 1.0, n_per_class)
        # Each arm winds two full turns; the second arm starts rotated by pi.
        angle = class_id * math.pi + radius * 4.0 * math.pi
        angle = angle + torch.randn(n_per_class) * noise
        points = torch.stack([radius * torch.sin(angle), radius * torch.cos(angle)], dim=1)
        labels = torch.full((n_per_class,), class_id, dtype=torch.long)
        arms.append((points, labels))

    X = torch.cat([points for points, _ in arms], dim=0)
    Y = torch.cat([labels for _, labels in arms], dim=0)

    # Mild scale expansion helps the MLP separate the spiral.
    X = 3.0 * X

    order = torch.randperm(X.shape[0])
    return X[order].to(DEVICE), Y[order].to(DEVICE)
72
+
73
+
74
+ # -----------------------------
75
+ # Model
76
+ # -----------------------------
77
+
78
class TinyMLP(nn.Module):
    """Four-layer ReLU MLP mapping 2-d points to 2 class logits."""

    def __init__(self, width: int = 128):
        super().__init__()
        # 2 -> width -> width -> width -> 2, ReLU between linear layers.
        layers = [nn.Linear(2, width), nn.ReLU()]
        for _ in range(2):
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(width, 2))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of 2-d inputs."""
        return self.net(x)
93
+
94
+
95
def linear_layers(model: nn.Module) -> List[nn.Linear]:
    """Collect every nn.Linear submodule of *model*, in traversal order."""
    found: List[nn.Linear] = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            found.append(module)
    return found
97
+
98
+
99
+ # -----------------------------
100
+ # Surprise Top-K machinery
101
+ # -----------------------------
102
+
103
@dataclass(frozen=True)
class BlockRef:
    # Identifies one "block": output row `row_index` of the weight matrix of
    # the `layer_index`-th nn.Linear (plus that row's bias element).
    layer_index: int
    row_index: int
107
+
108
+
109
class SurpriseTopKGradientBuilder:
    """
    Builds approximate gradients after a full backward pass.

    Block = one output row of a Linear weight matrix, plus its bias element.

    On active blocks:
        use true gradient.

    On inactive blocks:
        use predicted gradient from an exponential moving average.

    Score:
        surprise = ||true_gradient - predicted_gradient|| / (||true_gradient|| + eps)

    The highest-surprise blocks become active on non-refresh steps.
    """

    def __init__(
        self,
        model: nn.Module,
        beta: float = 0.95,
        active_fraction: float = 0.2,
        refresh_interval: int = 10,
        warmup_steps: int = 100,
        eps: float = 1e-12,
    ):
        # beta: EMA decay for the per-block gradient predictor.
        # active_fraction: fraction of blocks that get the true gradient on sparse steps.
        # refresh_interval: every Nth step all blocks use the true gradient.
        # warmup_steps: initial steps always use the full (true) gradient.
        self.model = model
        self.layers = linear_layers(model)
        self.beta = beta
        self.active_fraction = active_fraction
        self.refresh_interval = refresh_interval
        self.warmup_steps = warmup_steps
        self.eps = eps

        # One BlockRef per (layer, weight-row) pair.
        self.blocks: List[BlockRef] = []
        for li, layer in enumerate(self.layers):
            for row in range(layer.weight.shape[0]):
                self.blocks.append(BlockRef(li, row))

        # EMA gradient predictors, one tensor per layer (weight and bias),
        # initialized to zero (i.e. "predict no gradient").
        self.pred_w: Dict[int, torch.Tensor] = {}
        self.pred_b: Dict[int, torch.Tensor] = {}

        for li, layer in enumerate(self.layers):
            self.pred_w[li] = torch.zeros_like(layer.weight.data)
            if layer.bias is not None:
                self.pred_b[li] = torch.zeros_like(layer.bias.data)

        # Per-block surprise scores; start at 1 so every block is initially eligible.
        self.scores = torch.ones(len(self.blocks), device=DEVICE)

    def _choose_active_blocks(self, step: int) -> torch.Tensor:
        # Returns a boolean mask over blocks: True = use the true gradient.
        n = len(self.blocks)

        # Warmup and periodic refresh steps activate every block.
        if step < self.warmup_steps:
            return torch.ones(n, dtype=torch.bool, device=DEVICE)

        if step % self.refresh_interval == 0:
            return torch.ones(n, dtype=torch.bool, device=DEVICE)

        # Otherwise keep only the top-k most surprising blocks.
        k = max(1, int(self.active_fraction * n))
        active = torch.zeros(n, dtype=torch.bool, device=DEVICE)
        idx = torch.topk(self.scores, k=k).indices
        active[idx] = True
        return active

    @torch.no_grad()
    def build_and_install_approx_grads(self, step: int) -> Dict[str, float]:
        """Overwrite every layer's .grad with the block-sparse approximation.

        Must be called after loss.backward() and before optimizer.step().
        Returns diagnostic scalars (cosine to true grad, predictor quality,
        surprise-mass concentration, active fraction).
        """
        active = self._choose_active_blocks(step)

        true_parts = []
        approx_parts = []
        pred_parts = []

        # We build full approximate grad tensors, then overwrite .grad so Adam sees
        # the approximate gradient.
        approx_w: Dict[int, torch.Tensor] = {}
        approx_b: Dict[int, torch.Tensor] = {}
        for li, layer in enumerate(self.layers):
            approx_w[li] = torch.zeros_like(layer.weight.grad)
            if layer.bias is not None:
                approx_b[li] = torch.zeros_like(layer.bias.grad)

        # NOTE(review): this per-block Python loop calls .item() per block and
        # will force a host sync each iteration on GPU — acceptable for a
        # prototype that measures, not for a speedup version.
        for block_id, block in enumerate(self.blocks):
            li = block.layer_index
            row = block.row_index
            layer = self.layers[li]
            is_active = bool(active[block_id].item())

            # g_* = true gradient for this block, p_* = EMA prediction.
            g_w = layer.weight.grad[row].detach().clone()
            p_w = self.pred_w[li][row].detach().clone()

            if layer.bias is not None:
                g_b = layer.bias.grad[row].detach().clone()
                p_b = self.pred_b[li][row].detach().clone()
            else:
                g_b = None
                p_b = None

            # Score is computed against the predictor BEFORE updating predictor.
            true_vec_items = [g_w.flatten()]
            pred_vec_items = [p_w.flatten()]
            if g_b is not None:
                true_vec_items.append(g_b.view(1))
                pred_vec_items.append(p_b.view(1))

            true_vec_block = torch.cat(true_vec_items)
            pred_vec_block = torch.cat(pred_vec_items)
            residual = true_vec_block - pred_vec_block
            # Normalized surprise: relative prediction error for this block.
            self.scores[block_id] = torch.norm(residual) / (torch.norm(true_vec_block) + self.eps)

            # Active block -> true gradient; inactive -> predicted gradient.
            if is_active:
                a_w = g_w
                a_b = g_b
            else:
                a_w = p_w
                a_b = p_b

            approx_w[li][row] = a_w
            if layer.bias is not None:
                approx_b[li][row] = a_b

            # Update EMA predictor from the true gradient. We allow this in the
            # prototype because full gradients are being computed for measurement.
            # A speedup version would only update this from refresh/active blocks.
            self.pred_w[li][row].mul_(self.beta).add_(g_w, alpha=1.0 - self.beta)
            if layer.bias is not None:
                self.pred_b[li][row].mul_(self.beta).add_(g_b, alpha=1.0 - self.beta)

            approx_vec_items = [a_w.flatten()]
            if a_b is not None:
                approx_vec_items.append(a_b.view(1))

            true_parts.append(true_vec_block)
            pred_parts.append(pred_vec_block)
            approx_parts.append(torch.cat(approx_vec_items))

        # Install approximate gradients for the optimizer.
        for li, layer in enumerate(self.layers):
            layer.weight.grad.copy_(approx_w[li])
            if layer.bias is not None:
                layer.bias.grad.copy_(approx_b[li])

        # Flatten everything for global diagnostics.
        true_vec = torch.cat(true_parts)
        pred_vec = torch.cat(pred_parts)
        approx_vec = torch.cat(approx_parts)

        cosine = F.cosine_similarity(true_vec, approx_vec, dim=0).item()
        # Fraction of true-gradient energy explained by the EMA predictor alone.
        pred_explained = 1.0 - (
            torch.norm(true_vec - pred_vec).pow(2) / (torch.norm(true_vec).pow(2) + self.eps)
        ).item()

        # Share of total surprise carried by the top 20% of blocks
        # (heavy-tailedness proxy).
        k20 = max(1, int(0.2 * len(self.blocks)))
        sorted_scores = torch.sort(self.scores.detach(), descending=True).values
        top20_mass = (sorted_scores[:k20].sum() / (sorted_scores.sum() + self.eps)).item()

        return {
            "active_fraction": float(active.float().mean().item()),
            "cosine_true_vs_approx": cosine,
            "pred_explained_fraction": pred_explained,
            "top20_surprise_mass": top20_mass,
        }
270
+
271
+
272
+ # -----------------------------
273
+ # Metrics and training
274
+ # -----------------------------
275
+
276
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of samples in (X, y) that *model* classifies correctly."""
    model.eval()
    with torch.no_grad():
        predictions = model(X).argmax(dim=1)
        correct = (predictions == y).float()
        return correct.mean().item()
281
+
282
+
283
def train_baseline(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a TinyMLP with plain Adam on random minibatches.

    Returns the trained model plus a history of periodic checkpoints
    (step, loss, full-dataset accuracy).
    """
    model = TinyMLP().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        batch = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        xb, yb = X[batch], y[batch]

        loss = F.cross_entropy(model(xb), yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        # Log every 100 steps and always include the final step.
        if step % 100 == 0 or step == steps - 1:
            history.append(
                {
                    "step": step,
                    "loss": float(loss.item()),
                    "accuracy": accuracy(model, X, y),
                }
            )

    return model, history
313
+
314
+
315
def train_surprise_topk(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
    active_fraction: float = 0.2,
    refresh_interval: int = 10,
    warmup_steps: int = 100,
    beta: float = 0.95,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a TinyMLP with Adam fed surprise-top-k approximate gradients.

    Identical loop to train_baseline, except that after each backward pass the
    SurpriseTopKGradientBuilder replaces .grad with its block-sparse
    approximation before the optimizer step.
    """
    model = TinyMLP().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    builder = SurpriseTopKGradientBuilder(
        model,
        beta=beta,
        active_fraction=active_fraction,
        refresh_interval=refresh_interval,
        warmup_steps=warmup_steps,
    )
    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        batch = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        xb, yb = X[batch], y[batch]

        loss = F.cross_entropy(model(xb), yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()

        # Swap true gradients for the approximation before Adam consumes them.
        diagnostics = builder.build_and_install_approx_grads(step)
        opt.step()

        if step % 100 == 0 or step == steps - 1:
            record = {
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
            }
            record.update(diagnostics)
            history.append(record)

    return model, history
360
+
361
+
362
def print_last(label: str, history: List[Dict[str, float]]) -> None:
    """Pretty-print every field of the final logged history row."""
    print(f"\n{label}")
    final_row = history[-1]
    for key, value in final_row.items():
        if isinstance(value, float):
            print(f" {key:28s}: {value:.4f}")
        else:
            print(f" {key:28s}: {value}")
369
+
370
+
371
def print_checkpoints(label: str, history: List[Dict[str, float]]) -> None:
    """Print up to roughly 8 evenly spaced checkpoint rows from *history*."""
    print(f"\n{label} checkpoints:")
    stride = max(1, len(history) // 8)
    for row in history[::stride]:
        # Sparse-policy rows carry extra diagnostics; baseline rows do not.
        if "cosine_true_vs_approx" in row:
            extra = (
                f" active={row['active_fraction']:.2f}"
                f" cos={row['cosine_true_vs_approx']:.3f}"
                f" pred_expl={row['pred_explained_fraction']:.3f}"
                f" top20_mass={row['top20_surprise_mass']:.3f}"
            )
        else:
            extra = ""
        print(
            f"step={row['step']:4d} "
            f"loss={row['loss']:.4f} "
            f"acc={row['accuracy']:.3f}"
            f"{extra}"
        )
389
+
390
+
391
def main() -> None:
    """Run the full-Adam baseline and the surprise-top-k run, then report."""
    X, y = make_spiral()

    baseline_model, baseline_hist = train_baseline(X, y)
    surprise_model, surprise_hist = train_surprise_topk(
        X,
        y,
        active_fraction=0.2,
        refresh_interval=10,
        warmup_steps=100,
        beta=0.95,
    )

    print_last("Baseline full Adam", baseline_hist)
    print_last("Surprise Top-K simulated Adam", surprise_hist)
    print_checkpoints("Baseline", baseline_hist)
    print_checkpoints("Surprise Top-K", surprise_hist)

    print("\nHow to read this:")
    for hint in (
        " cos near 1.0 => approximate update points like the true gradient",
        " pred_expl > 0 => predictor beats zero as a gradient guess",
        " top20_mass high => surprise is heavy-tailed / concentrated",
        " accuracy close => approximation did not wreck training",
    ):
        print(hint)
415
+
416
+
417
# Script entry point: run the experiment when executed directly.
if __name__ == "__main__":
    main()
experiments/surprise_topk_gradient_prototype-v3.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Surprise Top-K Gradient Prototype, v3
3
+
4
+ Goal
5
+ ----
6
+ Test the hypothesis:
7
+
8
+ gradient_t ≈ predicted_gradient_t + sparse_surprising_residual_t
9
+
10
+ What changed from v2
11
+ --------------------
12
+ 1. Checkpoints no longer accidentally land mostly on refresh steps.
13
+ 2. Surprise mass is measured using absolute residual norm, not a normalized ratio
14
+ that explodes when true gradients are tiny.
15
+ 3. We track both:
16
+ - score_for_selection: normalized surprise, used to decide active blocks
17
+ - residual_mass: absolute residual norm, used to test heavy-tailed structure
18
+ 4. We print separate summaries for refresh steps and sparse steps.
19
+ 5. We compare surprise-top-k against magnitude-top-k and random-top-k policies.
20
+
21
+ Important caveat
22
+ ----------------
23
+ This still computes the full gradient every step. That is intentional. This is a
24
+ hypothesis test: does a sparse "surprising residual" preserve the useful update
25
+ signal? If yes, a later version can try to skip real backward computation.
26
+
27
+ Run
28
+ ---
29
+ python3 surprise_topk_gradient_prototype.py
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import math
35
+ import random
36
+ from dataclasses import dataclass
37
+ from typing import Dict, List, Literal, Tuple
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+
43
+
44
+ SEED = 7
45
+ random.seed(SEED)
46
+ torch.manual_seed(SEED)
47
+
48
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
49
+ Policy = Literal["surprise", "magnitude", "random"]
50
+
51
+
52
+ # -----------------------------
53
+ # Toy data: 2-class spiral
54
+ # -----------------------------
55
+
56
def make_spiral(n_per_class: int = 1024, noise: float = 0.12) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate a shuffled two-arm spiral classification dataset.

    Returns (features, labels): features are (2 * n_per_class, 2) points,
    labels are the matching class ids (0 or 1); both are moved to DEVICE.
    """
    arms = []
    for class_id in range(2):
        radius = torch.linspace(0.0, 1.0, n_per_class)
        # Each arm winds two full turns; the second arm starts rotated by pi.
        angle = class_id * math.pi + radius * 4.0 * math.pi
        angle = angle + torch.randn(n_per_class) * noise
        points = torch.stack([radius * torch.sin(angle), radius * torch.cos(angle)], dim=1)
        labels = torch.full((n_per_class,), class_id, dtype=torch.long)
        arms.append((points, labels))

    X = torch.cat([points for points, _ in arms], dim=0)
    Y = torch.cat([labels for _, labels in arms], dim=0)
    X = 3.0 * X  # mild scale expansion helps the MLP

    order = torch.randperm(X.shape[0])
    return X[order].to(DEVICE), Y[order].to(DEVICE)
77
+
78
+
79
+ # -----------------------------
80
+ # Model
81
+ # -----------------------------
82
+
83
class TinyMLP(nn.Module):
    """Four-layer ReLU MLP mapping 2-d points to 2 class logits."""

    def __init__(self, width: int = 128):
        super().__init__()
        # 2 -> width -> width -> width -> 2, ReLU between linear layers.
        layers = [nn.Linear(2, width), nn.ReLU()]
        for _ in range(2):
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(width, 2))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of 2-d inputs."""
        return self.net(x)
98
+
99
+
100
def linear_layers(model: nn.Module) -> List[nn.Linear]:
    """Collect every nn.Linear submodule of *model*, in traversal order."""
    found: List[nn.Linear] = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            found.append(module)
    return found
102
+
103
+
104
+ # -----------------------------
105
+ # Surprise Top-K machinery
106
+ # -----------------------------
107
+
108
@dataclass(frozen=True)
class BlockRef:
    # Identifies one "block": output row `row_index` of the weight matrix of
    # the `layer_index`-th nn.Linear (plus that row's bias element).
    layer_index: int
    row_index: int
112
+
113
+
114
class SparseGradientBuilder:
    """
    Builds approximate gradients after a full backward pass.

    Block = one output row of a Linear weight matrix, plus its bias element.

    Approx gradient:
        active blocks -> true gradient
        inactive blocks -> EMA-predicted gradient

    Selection policies:
        surprise -> choose blocks where true gradient differs most from prediction
        magnitude -> choose blocks with largest true gradient norm
        random -> choose random blocks
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy = "surprise",
        beta: float = 0.95,
        active_fraction: float = 0.2,
        refresh_interval: int = 10,
        warmup_steps: int = 100,
        eps: float = 1e-12,
    ):
        # policy: which block-selection rule to use on sparse steps.
        # beta: EMA decay for the per-block gradient predictor.
        # active_fraction: fraction of blocks given the true gradient per sparse step.
        # refresh_interval / warmup_steps: steps that use the full gradient everywhere.
        self.model = model
        self.layers = linear_layers(model)
        self.policy = policy
        self.beta = beta
        self.active_fraction = active_fraction
        self.refresh_interval = refresh_interval
        self.warmup_steps = warmup_steps
        self.eps = eps

        # One BlockRef per (layer, weight-row) pair.
        self.blocks: List[BlockRef] = []
        for li, layer in enumerate(self.layers):
            for row in range(layer.weight.shape[0]):
                self.blocks.append(BlockRef(li, row))

        # EMA gradient predictors, one tensor per layer, initialized to zero.
        self.pred_w: Dict[int, torch.Tensor] = {}
        self.pred_b: Dict[int, torch.Tensor] = {}

        for li, layer in enumerate(self.layers):
            self.pred_w[li] = torch.zeros_like(layer.weight.data)
            if layer.bias is not None:
                self.pred_b[li] = torch.zeros_like(layer.bias.data)

        # Per-block statistics (all start at 1 so every block is eligible):
        # selection_score -> normalized surprise, drives the "surprise" policy
        # residual_mass   -> absolute ||g - pred||, heavy-tail diagnostic
        # gradient_mass   -> absolute ||g||, drives the "magnitude" policy
        self.selection_score = torch.ones(len(self.blocks), device=DEVICE)
        self.residual_mass = torch.ones(len(self.blocks), device=DEVICE)
        self.gradient_mass = torch.ones(len(self.blocks), device=DEVICE)

    def _is_refresh_step(self, step: int) -> bool:
        # Refresh steps (warmup or every Nth step) use the true gradient everywhere.
        return step < self.warmup_steps or step % self.refresh_interval == 0

    def _choose_active_blocks(self, step: int) -> torch.Tensor:
        # Returns a boolean mask over blocks: True = use the true gradient.
        n = len(self.blocks)

        if self._is_refresh_step(step):
            return torch.ones(n, dtype=torch.bool, device=DEVICE)

        k = max(1, int(self.active_fraction * n))
        active = torch.zeros(n, dtype=torch.bool, device=DEVICE)

        # Dispatch on the configured selection policy.
        if self.policy == "surprise":
            idx = torch.topk(self.selection_score, k=k).indices
        elif self.policy == "magnitude":
            idx = torch.topk(self.gradient_mass, k=k).indices
        elif self.policy == "random":
            idx = torch.randperm(n, device=DEVICE)[:k]
        else:
            raise ValueError(f"Unknown policy: {self.policy}")

        active[idx] = True
        return active

    @torch.no_grad()
    def build_and_install_approx_grads(self, step: int) -> Dict[str, float]:
        """Overwrite every layer's .grad with the block-sparse approximation.

        Must be called after loss.backward() and before optimizer.step().
        Returns diagnostic scalars; "is_refresh" flags full-gradient steps.
        """
        active = self._choose_active_blocks(step)
        is_refresh = self._is_refresh_step(step)

        true_parts = []
        approx_parts = []
        pred_parts = []

        # Full-size buffers that will replace .grad at the end.
        approx_w: Dict[int, torch.Tensor] = {}
        approx_b: Dict[int, torch.Tensor] = {}
        for li, layer in enumerate(self.layers):
            approx_w[li] = torch.zeros_like(layer.weight.grad)
            if layer.bias is not None:
                approx_b[li] = torch.zeros_like(layer.bias.grad)

        # NOTE(review): per-block Python loop with .item() syncs; prototype-only.
        for block_id, block in enumerate(self.blocks):
            li = block.layer_index
            row = block.row_index
            layer = self.layers[li]
            is_active = bool(active[block_id].item())

            # g_* = true gradient for this block, p_* = EMA prediction.
            g_w = layer.weight.grad[row].detach().clone()
            p_w = self.pred_w[li][row].detach().clone()

            if layer.bias is not None:
                g_b = layer.bias.grad[row].detach().clone()
                p_b = self.pred_b[li][row].detach().clone()
            else:
                g_b = None
                p_b = None

            # Scores are computed against the predictor BEFORE it is updated.
            true_vec_items = [g_w.flatten()]
            pred_vec_items = [p_w.flatten()]
            if g_b is not None:
                true_vec_items.append(g_b.view(1))
                pred_vec_items.append(p_b.view(1))

            true_vec_block = torch.cat(true_vec_items)
            pred_vec_block = torch.cat(pred_vec_items)
            residual = true_vec_block - pred_vec_block

            grad_norm = torch.norm(true_vec_block)
            residual_norm = torch.norm(residual)

            self.gradient_mass[block_id] = grad_norm
            self.residual_mass[block_id] = residual_norm

            # Selection score deliberately normalized, but with a floor so tiny
            # gradients do not create absurd ratios.
            denom = torch.maximum(grad_norm, torch.tensor(1e-6, device=DEVICE))
            self.selection_score[block_id] = residual_norm / denom

            # Active block -> true gradient; inactive -> predicted gradient.
            if is_active:
                a_w = g_w
                a_b = g_b
            else:
                a_w = p_w
                a_b = p_b

            approx_w[li][row] = a_w
            if layer.bias is not None:
                approx_b[li][row] = a_b

            # Prototype choice: update EMA from true gradient on every step,
            # because we computed full gradients for measurement. A speedup version
            # would only update this on refresh/active blocks.
            self.pred_w[li][row].mul_(self.beta).add_(g_w, alpha=1.0 - self.beta)
            if layer.bias is not None:
                self.pred_b[li][row].mul_(self.beta).add_(g_b, alpha=1.0 - self.beta)

            approx_vec_items = [a_w.flatten()]
            if a_b is not None:
                approx_vec_items.append(a_b.view(1))

            true_parts.append(true_vec_block)
            pred_parts.append(pred_vec_block)
            approx_parts.append(torch.cat(approx_vec_items))

        # Install approximate gradients for the optimizer.
        for li, layer in enumerate(self.layers):
            layer.weight.grad.copy_(approx_w[li])
            if layer.bias is not None:
                layer.bias.grad.copy_(approx_b[li])

        # Flatten everything for global diagnostics.
        true_vec = torch.cat(true_parts)
        pred_vec = torch.cat(pred_parts)
        approx_vec = torch.cat(approx_parts)

        true_norm = torch.norm(true_vec)
        pred_error_norm = torch.norm(true_vec - pred_vec)

        cosine = F.cosine_similarity(true_vec, approx_vec, dim=0).item()
        # Fraction of true-gradient energy explained by the EMA predictor alone.
        pred_explained = 1.0 - (
            pred_error_norm.pow(2) / (true_norm.pow(2) + self.eps)
        ).item()

        # Share of total mass carried by the top 20% of blocks (heavy-tail proxy),
        # measured separately for residual mass and raw gradient mass.
        k20 = max(1, int(0.2 * len(self.blocks)))

        sorted_residual = torch.sort(self.residual_mass.detach(), descending=True).values
        top20_residual_mass = (sorted_residual[:k20].sum() / (sorted_residual.sum() + self.eps)).item()

        sorted_gradient = torch.sort(self.gradient_mass.detach(), descending=True).values
        top20_gradient_mass = (sorted_gradient[:k20].sum() / (sorted_gradient.sum() + self.eps)).item()

        return {
            "is_refresh": float(is_refresh),
            "active_fraction": float(active.float().mean().item()),
            "cosine_true_vs_approx": cosine,
            "pred_explained_fraction": pred_explained,
            "top20_residual_mass": top20_residual_mass,
            "top20_gradient_mass": top20_gradient_mass,
            "true_grad_norm": float(true_norm.item()),
            "pred_error_norm": float(pred_error_norm.item()),
        }
304
+
305
+
306
+ # -----------------------------
307
+ # Metrics and training
308
+ # -----------------------------
309
+
310
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of samples in (X, y) that *model* classifies correctly."""
    model.eval()
    with torch.no_grad():
        predictions = model(X).argmax(dim=1)
        correct = (predictions == y).float()
        return correct.mean().item()
315
+
316
+
317
def train_baseline(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a TinyMLP with plain Adam on random minibatches.

    Returns the trained model plus a history of periodic checkpoints.
    """
    model = TinyMLP().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        batch = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        xb, yb = X[batch], y[batch]

        loss = F.cross_entropy(model(xb), yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        # Log every 97 steps (coprime with refresh_interval=10 so checkpoints
        # do not all land on refresh steps) plus the final step.
        if step % 97 == 0 or step == steps - 1:
            history.append(
                {
                    "step": step,
                    "loss": float(loss.item()),
                    "accuracy": accuracy(model, X, y),
                }
            )

    return model, history
347
+
348
+
349
def train_sparse_policy(
    X: torch.Tensor,
    y: torch.Tensor,
    policy: Policy,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
    active_fraction: float = 0.2,
    refresh_interval: int = 10,
    warmup_steps: int = 100,
    beta: float = 0.95,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a TinyMLP with Adam fed block-sparse approximate gradients.

    Same loop as train_baseline, except a SparseGradientBuilder with the given
    *policy* replaces .grad after each backward pass.
    """
    model = TinyMLP().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    builder = SparseGradientBuilder(
        model,
        policy=policy,
        beta=beta,
        active_fraction=active_fraction,
        refresh_interval=refresh_interval,
        warmup_steps=warmup_steps,
    )
    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        batch = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        xb, yb = X[batch], y[batch]

        loss = F.cross_entropy(model(xb), yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()

        # Swap true gradients for the approximation before Adam consumes them.
        diagnostics = builder.build_and_install_approx_grads(step)
        opt.step()

        # 97 is coprime with refresh_interval so logs hit sparse steps too.
        if step % 97 == 0 or step == steps - 1:
            record = {
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
            }
            record.update(diagnostics)
            history.append(record)

    return model, history
396
+
397
+
398
def avg_sparse_metric(history: List[Dict[str, float]], key: str) -> float:
    """Average *key* over logged rows that were sparse (non-refresh) steps.

    Rows with is_refresh != 0 are refresh steps and excluded. Rows that lack
    *key* entirely (e.g. baseline-style rows with no diagnostics) are also
    skipped instead of raising KeyError. Returns NaN when no row qualifies.
    """
    vals = [
        row[key]
        for row in history
        if row.get("is_refresh", 0.0) == 0.0 and key in row
    ]
    if not vals:
        return float("nan")
    return sum(vals) / len(vals)
403
+
404
+
405
def print_last(label: str, history: List[Dict[str, float]]) -> None:
    """Pretty-print every field of the final logged history row."""
    print(f"\n{label}")
    final_row = history[-1]
    for key, value in final_row.items():
        if isinstance(value, float):
            print(f" {key:28s}: {value:.4f}")
        else:
            print(f" {key:28s}: {value}")
412
+
413
+
414
def print_sparse_summary(label: str, history: List[Dict[str, float]]) -> None:
    """Print sparse-step metric averages plus the final accuracy and loss."""
    last = history[-1]
    print(f"\n{label} sparse-step averages from logged checkpoints")
    metric_keys = (
        "cosine_true_vs_approx",
        "pred_explained_fraction",
        "top20_residual_mass",
        "top20_gradient_mass",
    )
    for key in metric_keys:
        print(f" avg {key:24s}: {avg_sparse_metric(history, key):.4f}")
    print(f" final accuracy : {last['accuracy']:.4f}")
    print(f" final loss : {last['loss']:.4f}")
426
+
427
+
428
def print_checkpoints(label: str, history: List[Dict[str, float]]) -> None:
    """Print up to roughly 8 evenly spaced checkpoint rows from *history*."""
    print(f"\n{label} checkpoints:")
    stride = max(1, len(history) // 8)
    for row in history[::stride]:
        # Sparse-policy rows carry extra diagnostics; baseline rows do not.
        if "cosine_true_vs_approx" in row:
            extra = (
                f" refresh={int(row['is_refresh'])}"
                f" active={row['active_fraction']:.2f}"
                f" cos={row['cosine_true_vs_approx']:.3f}"
                f" pred_expl={row['pred_explained_fraction']:.3f}"
                f" top20_resid={row['top20_residual_mass']:.3f}"
                f" top20_grad={row['top20_gradient_mass']:.3f}"
            )
        else:
            extra = ""
        print(
            f"step={row['step']:4d} "
            f"loss={row['loss']:.4f} "
            f"acc={row['accuracy']:.3f}"
            f"{extra}"
        )
448
+
449
+
450
def main() -> None:
    """Compare full Adam against the three sparse block-selection policies."""
    X, y = make_spiral()

    baseline_model, baseline_hist = train_baseline(X, y)

    results = {}
    for policy in ["surprise", "magnitude", "random"]:
        _, hist = train_sparse_policy(
            X,
            y,
            policy=policy,
            active_fraction=0.2,
            refresh_interval=10,
            warmup_steps=100,
            beta=0.95,
        )
        results[policy] = hist

    print_last("Baseline full Adam", baseline_hist)
    for policy, hist in results.items():
        print_last(f"{policy.title()} Top-K simulated Adam", hist)

    print_checkpoints("Baseline", baseline_hist)
    for policy, hist in results.items():
        print_checkpoints(f"{policy.title()} Top-K", hist)
        print_sparse_summary(f"{policy.title()} Top-K", hist)

    print("\nHow to read this:")
    for hint in (
        " refresh=0 => a real sparse/approximate logged step",
        " cos near 1.0 => approximate update points like the true gradient",
        " pred_expl > 0 => predictor beats zero as a gradient guess",
        " top20_resid high => prediction error is heavy-tailed/concentrated",
        " top20_grad high => raw gradient mass is heavy-tailed/concentrated",
        " surprise > baselines => surprise is doing more than ordinary top-k/random",
    ):
        print(hint)
484
+
485
+
486
# Script entry point: run the experiment when executed directly.
if __name__ == "__main__":
    main()
experiments/surprise_topk_gradient_prototype-v4.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Surprise / Predicted-Magnitude Top-K Gradient Prototype, v4
3
+
4
+ Goal
5
+ ----
6
+ Test the practical version of the hypothesis:
7
+
8
+ The gradient/update signal is heavy-tailed, and the high-mass blocks are
9
+ predictable enough that we can choose where to spend backward/update compute.
10
+
11
+ What v4 adds
12
+ ------------
13
+ A new policy:
14
+
15
+ predicted_magnitude
16
+
17
+ This selects active blocks using only a historical EMA of block gradient norms,
18
+ not the current gradient. This is much closer to something that could eventually
19
+ save backward computation, because the active set is known before using the
20
+ current full gradient.
21
+
22
+ Policies compared
23
+ -----------------
24
+ 1. surprise
25
+ Select blocks whose current gradient is least predictable from EMA gradient.
26
+ This is mostly a diagnostic/oracle because it uses the current gradient.
27
+
28
+ 2. magnitude
29
+ Select blocks with largest current gradient norm.
30
+ This is an oracle upper-bound for simple top-k block sparsification.
31
+
32
+ 3. predicted_magnitude
33
+ Select blocks with largest EMA-predicted gradient norm.
34
+ This is the important practical test.
35
+
36
+ 4. random
37
+ Control baseline.
38
+
39
+ Important caveat
40
+ ----------------
41
+ This still computes the full gradient every step. That is intentional. We are
42
+ measuring whether the active-set prediction would have worked. Actual speedup
43
+ would require skipping or restricting backward computation for inactive blocks.
44
+
45
+ Run
46
+ ---
47
+ python3 surprise_topk_gradient_prototype.py
48
+ """
49
+
50
+ from __future__ import annotations
51
+
52
+ import math
53
+ import random
54
+ from dataclasses import dataclass
55
+ from typing import Dict, List, Literal, Tuple
56
+
57
+ import torch
58
+ import torch.nn as nn
59
+ import torch.nn.functional as F
60
+
61
+
62
+ SEED = 7
63
+ random.seed(SEED)
64
+ torch.manual_seed(SEED)
65
+
66
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
67
+ Policy = Literal["surprise", "magnitude", "predicted_magnitude", "random"]
68
+
69
+
70
+ # -----------------------------
71
+ # Toy data: 2-class spiral
72
+ # -----------------------------
73
+
74
def make_spiral(n_per_class: int = 1024, noise: float = 0.12) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate a shuffled two-arm spiral classification dataset.

    Returns (features, labels): features are (2 * n_per_class, 2) points,
    labels are the matching class ids (0 or 1); both are moved to DEVICE.
    """
    arms = []
    for class_id in range(2):
        radius = torch.linspace(0.0, 1.0, n_per_class)
        # Each arm winds two full turns; the second arm starts rotated by pi.
        angle = class_id * math.pi + radius * 4.0 * math.pi
        angle = angle + torch.randn(n_per_class) * noise
        points = torch.stack([radius * torch.sin(angle), radius * torch.cos(angle)], dim=1)
        labels = torch.full((n_per_class,), class_id, dtype=torch.long)
        arms.append((points, labels))

    X = torch.cat([points for points, _ in arms], dim=0)
    Y = torch.cat([labels for _, labels in arms], dim=0)
    X = 3.0 * X  # mild scale expansion helps the MLP

    order = torch.randperm(X.shape[0])
    return X[order].to(DEVICE), Y[order].to(DEVICE)
95
+
96
+
97
+ # -----------------------------
98
+ # Model
99
+ # -----------------------------
100
+
101
class TinyMLP(nn.Module):
    """Small fixed-depth ReLU MLP mapping 2-D inputs to 2 class logits."""

    def __init__(self, width: int = 128):
        super().__init__()
        # Assemble layers in the exact same order as a literal Sequential
        # so parameter-initialization RNG draws are unchanged.
        stack: List[nn.Module] = [nn.Linear(2, width), nn.ReLU()]
        stack += [nn.Linear(width, width), nn.ReLU()]
        stack += [nn.Linear(width, width), nn.ReLU()]
        stack.append(nn.Linear(width, 2))
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits of shape (batch, 2)."""
        return self.net(x)
116
+
117
+
118
def linear_layers(model: nn.Module) -> List[nn.Linear]:
    """Collect every nn.Linear submodule of *model* in traversal order."""
    found: List[nn.Linear] = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            found.append(module)
    return found
120
+
121
+
122
+ # -----------------------------
123
+ # Sparse gradient machinery
124
+ # -----------------------------
125
+
126
@dataclass(frozen=True)
class BlockRef:
    """Immutable reference to one sparsity "block".

    A block is a single output row of the layer_index-th nn.Linear's weight
    matrix (plus, implicitly, that row's bias element).
    """

    # Index into the list returned by linear_layers(model).
    layer_index: int
    # Output-row index within that layer's weight matrix.
    row_index: int
130
+
131
+
132
class SparseGradientBuilder:
    """
    Builds approximate gradients after a full backward pass.

    Block = one output row of a Linear weight matrix, plus its bias element.

    Approx gradient:
        active blocks   -> true gradient
        inactive blocks -> EMA-predicted gradient

    Selection policies:
        surprise            -> largest current residual: ||g - pred_g||
        magnitude           -> largest current gradient norm: ||g||
        predicted_magnitude -> largest historical EMA gradient norm
        random              -> random active blocks
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy = "predicted_magnitude",
        grad_beta: float = 0.95,
        mass_beta: float = 0.95,
        active_fraction: float = 0.2,
        refresh_interval: int = 10,
        warmup_steps: int = 100,
        eps: float = 1e-12,
    ):
        # grad_beta: EMA decay for the per-block predicted gradient tensors.
        # mass_beta: EMA decay for the per-block gradient-norm tracker used
        #            by the predicted_magnitude policy.
        self.model = model
        self.layers = linear_layers(model)
        self.policy = policy
        self.grad_beta = grad_beta
        self.mass_beta = mass_beta
        self.active_fraction = active_fraction
        self.refresh_interval = refresh_interval
        self.warmup_steps = warmup_steps
        self.eps = eps

        # Enumerate every (layer, output-row) pair as one selectable block.
        self.blocks: List[BlockRef] = []
        for li, layer in enumerate(self.layers):
            for row in range(layer.weight.shape[0]):
                self.blocks.append(BlockRef(li, row))

        # EMA predictors of the full weight/bias gradients, keyed by layer index.
        self.pred_w: Dict[int, torch.Tensor] = {}
        self.pred_b: Dict[int, torch.Tensor] = {}

        for li, layer in enumerate(self.layers):
            self.pred_w[li] = torch.zeros_like(layer.weight.data)
            if layer.bias is not None:
                self.pred_b[li] = torch.zeros_like(layer.bias.data)

        # Per-block scalar trackers, initialized to ones so that every block
        # is a plausible candidate before any gradients have been observed.
        n = len(self.blocks)
        self.current_gradient_mass = torch.ones(n, device=DEVICE)
        self.current_residual_mass = torch.ones(n, device=DEVICE)
        self.predicted_gradient_mass = torch.ones(n, device=DEVICE)
        # NOTE(review): selection_score is initialized here but never read
        # or updated elsewhere in this class — confirm before removing.
        self.selection_score = torch.ones(n, device=DEVICE)

        # Active-set stability diagnostic.
        self.prev_active = torch.zeros(n, dtype=torch.bool, device=DEVICE)

    def _is_refresh_step(self, step: int) -> bool:
        """True when this step uses the full (dense) gradient for all blocks."""
        return step < self.warmup_steps or step % self.refresh_interval == 0

    def _choose_active_blocks(self, step: int) -> torch.Tensor:
        """Return a boolean mask over blocks selecting this step's active set.

        On refresh/warmup steps every block is active; otherwise the top-k
        blocks (k = active_fraction * n) under the configured policy.
        """
        n = len(self.blocks)

        if self._is_refresh_step(step):
            return torch.ones(n, dtype=torch.bool, device=DEVICE)

        k = max(1, int(self.active_fraction * n))
        active = torch.zeros(n, dtype=torch.bool, device=DEVICE)

        if self.policy == "surprise":
            # Uses last step's residual score when choosing before the current
            # gradient is processed. After full gradient is computed below, this
            # becomes an oracle-ish diagnostic of residual concentration.
            idx = torch.topk(self.current_residual_mass, k=k).indices
        elif self.policy == "magnitude":
            # Uses last observed current_gradient_mass from previous step, not the
            # current one, at selection time. Still a strong baseline.
            idx = torch.topk(self.current_gradient_mass, k=k).indices
        elif self.policy == "predicted_magnitude":
            # This is the important practical policy: choose from historical EMA.
            idx = torch.topk(self.predicted_gradient_mass, k=k).indices
        elif self.policy == "random":
            idx = torch.randperm(n, device=DEVICE)[:k]
        else:
            raise ValueError(f"Unknown policy: {self.policy}")

        active[idx] = True
        return active

    @staticmethod
    def _topk_mask(values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask marking the top `fraction` of entries in *values*."""
        n = values.numel()
        k = max(1, int(fraction * n))
        mask = torch.zeros(n, dtype=torch.bool, device=values.device)
        mask[torch.topk(values, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard similarity |a & b| / |a | b| of two boolean masks (0.0 if both empty)."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    @torch.no_grad()
    def build_and_install_approx_grads(self, step: int) -> Dict[str, float]:
        """Replace .grad on every Linear with the block-sparse approximation.

        Must be called after loss.backward() and before optimizer.step().
        Active blocks keep their true gradient; inactive blocks receive the
        EMA-predicted gradient. Also updates the EMA predictors and mass
        trackers, and returns a dict of diagnostics for logging.
        """
        active = self._choose_active_blocks(step)
        is_refresh = self._is_refresh_step(step)

        # Flattened per-block vectors, accumulated for global diagnostics.
        true_parts = []
        approx_parts = []
        pred_parts = []

        # Fresh gradient buffers to be installed at the end.
        approx_w: Dict[int, torch.Tensor] = {}
        approx_b: Dict[int, torch.Tensor] = {}
        for li, layer in enumerate(self.layers):
            approx_w[li] = torch.zeros_like(layer.weight.grad)
            if layer.bias is not None:
                approx_b[li] = torch.zeros_like(layer.bias.grad)

        # Temporary arrays for current-step measurements.
        new_gradient_mass = torch.zeros_like(self.current_gradient_mass)
        new_residual_mass = torch.zeros_like(self.current_residual_mass)

        for block_id, block in enumerate(self.blocks):
            li = block.layer_index
            row = block.row_index
            layer = self.layers[li]
            is_active = bool(active[block_id].item())

            # True and EMA-predicted gradients for this block's weight row.
            g_w = layer.weight.grad[row].detach().clone()
            p_w = self.pred_w[li][row].detach().clone()

            if layer.bias is not None:
                g_b = layer.bias.grad[row].detach().clone()
                p_b = self.pred_b[li][row].detach().clone()
            else:
                g_b = None
                p_b = None

            # Flatten weight row + bias element into one block vector.
            true_vec_items = [g_w.flatten()]
            pred_vec_items = [p_w.flatten()]
            if g_b is not None:
                true_vec_items.append(g_b.view(1))
                pred_vec_items.append(p_b.view(1))

            true_vec_block = torch.cat(true_vec_items)
            pred_vec_block = torch.cat(pred_vec_items)
            residual_vec = true_vec_block - pred_vec_block

            grad_norm = torch.norm(true_vec_block)
            residual_norm = torch.norm(residual_vec)

            new_gradient_mass[block_id] = grad_norm
            new_residual_mass[block_id] = residual_norm

            # Active blocks pass through the true gradient; inactive blocks
            # fall back to the EMA prediction.
            if is_active:
                a_w = g_w
                a_b = g_b
            else:
                a_w = p_w
                a_b = p_b

            approx_w[li][row] = a_w
            if layer.bias is not None:
                approx_b[li][row] = a_b

            # Prototype choice: update predictors from true gradient because the
            # full gradient is available for measurement. A speedup version would
            # update fully only on refresh steps and partially on active blocks.
            self.pred_w[li][row].mul_(self.grad_beta).add_(g_w, alpha=1.0 - self.grad_beta)
            if layer.bias is not None:
                self.pred_b[li][row].mul_(self.grad_beta).add_(g_b, alpha=1.0 - self.grad_beta)

            approx_vec_items = [a_w.flatten()]
            if a_b is not None:
                approx_vec_items.append(a_b.view(1))

            true_parts.append(true_vec_block)
            pred_parts.append(pred_vec_block)
            approx_parts.append(torch.cat(approx_vec_items))

        # Install approximate gradients for Adam.
        for li, layer in enumerate(self.layers):
            layer.weight.grad.copy_(approx_w[li])
            if layer.bias is not None:
                layer.bias.grad.copy_(approx_b[li])

        true_vec = torch.cat(true_parts)
        pred_vec = torch.cat(pred_parts)
        approx_vec = torch.cat(approx_parts)

        true_norm = torch.norm(true_vec)
        pred_error_norm = torch.norm(true_vec - pred_vec)

        # Global alignment of the installed gradient with the true gradient,
        # and the fraction of true-gradient energy explained by the predictor.
        cosine = F.cosine_similarity(true_vec, approx_vec, dim=0).item()
        pred_explained = 1.0 - (
            pred_error_norm.pow(2) / (true_norm.pow(2) + self.eps)
        ).item()

        # Oracle masks for diagnostics after seeing the true current gradient.
        oracle_magnitude_mask = self._topk_mask(new_gradient_mass, self.active_fraction)
        oracle_residual_mask = self._topk_mask(new_residual_mass, self.active_fraction)
        predicted_magnitude_mask = self._topk_mask(self.predicted_gradient_mass, self.active_fraction)

        active_vs_oracle_mag = self._jaccard(active, oracle_magnitude_mask)
        active_vs_oracle_resid = self._jaccard(active, oracle_residual_mask)
        predmag_vs_oracle_mag = self._jaccard(predicted_magnitude_mask, oracle_magnitude_mask)
        active_stability = self._jaccard(active, self.prev_active)

        self.prev_active = active.clone()

        # Concentration diagnostics: share of total mass held by the top 20%.
        k20 = max(1, int(0.2 * len(self.blocks)))

        sorted_residual = torch.sort(new_residual_mass.detach(), descending=True).values
        top20_residual_mass = (sorted_residual[:k20].sum() / (sorted_residual.sum() + self.eps)).item()

        sorted_gradient = torch.sort(new_gradient_mass.detach(), descending=True).values
        top20_gradient_mass = (sorted_gradient[:k20].sum() / (sorted_gradient.sum() + self.eps)).item()

        # Update mass trackers AFTER diagnostics, so predicted_magnitude really
        # uses only history at selection time.
        self.current_gradient_mass = new_gradient_mass
        self.current_residual_mass = new_residual_mass
        self.predicted_gradient_mass.mul_(self.mass_beta).add_(new_gradient_mass, alpha=1.0 - self.mass_beta)

        return {
            "is_refresh": float(is_refresh),
            "active_fraction": float(active.float().mean().item()),
            "cosine_true_vs_approx": cosine,
            "pred_explained_fraction": pred_explained,
            "top20_residual_mass": top20_residual_mass,
            "top20_gradient_mass": top20_gradient_mass,
            "active_vs_oracle_mag": active_vs_oracle_mag,
            "active_vs_oracle_resid": active_vs_oracle_resid,
            "predmag_vs_oracle_mag": predmag_vs_oracle_mag,
            "active_stability": active_stability,
            "true_grad_norm": float(true_norm.item()),
            "pred_error_norm": float(pred_error_norm.item()),
        }
374
+
375
+
376
+ # -----------------------------
377
+ # Metrics and training
378
+ # -----------------------------
379
+
380
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of rows of X whose argmax class prediction equals y."""
    model.eval()
    with torch.no_grad():
        predicted = model(X).argmax(dim=1)
    hits = predicted == y
    return hits.float().mean().item()
385
+
386
+
387
def train_baseline(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a fresh TinyMLP with dense Adam; return (model, logged history).

    History rows carry step, batch loss, and full-dataset accuracy, logged
    every 97 steps (a prime stride, to avoid aliasing with other periods)
    plus the final step.
    """
    model = TinyMLP().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        batch_idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        loss = F.cross_entropy(model(X[batch_idx]), y[batch_idx])

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 97 == 0 or step == steps - 1:
            history.append(
                {
                    "step": step,
                    "loss": float(loss.item()),
                    "accuracy": accuracy(model, X, y),
                }
            )

    return model, history
417
+
418
+
419
def train_sparse_policy(
    X: torch.Tensor,
    y: torch.Tensor,
    policy: Policy,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
    active_fraction: float = 0.2,
    refresh_interval: int = 10,
    warmup_steps: int = 100,
    grad_beta: float = 0.95,
    mass_beta: float = 0.95,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a fresh TinyMLP with Adam on block-sparse approximate gradients.

    Identical to train_baseline except that, after each full backward pass,
    SparseGradientBuilder overwrites .grad with the policy's block-sparse
    approximation before the optimizer step. Returns (model, history) where
    history rows also include the builder's diagnostics.
    """
    model = TinyMLP().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    builder = SparseGradientBuilder(
        model,
        policy=policy,
        grad_beta=grad_beta,
        mass_beta=mass_beta,
        active_fraction=active_fraction,
        refresh_interval=refresh_interval,
        warmup_steps=warmup_steps,
    )

    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        xb, yb = X[idx], y[idx]

        loss = F.cross_entropy(model(xb), yb)

        opt.zero_grad(set_to_none=True)
        loss.backward()

        # Full gradient is computed above; the builder then swaps in the
        # block-sparse approximation and records diagnostics.
        diagnostics = builder.build_and_install_approx_grads(step)
        opt.step()

        # Same logging cadence as train_baseline (every 97 steps + last step).
        if step % 97 == 0 or step == steps - 1:
            history.append({
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
                **diagnostics,
            })

    return model, history
468
+
469
+
470
def sparse_rows(history: List[Dict[str, float]]) -> List[Dict[str, float]]:
    """Keep only history rows logged on non-refresh (sparse) steps.

    Rows without an "is_refresh" key are treated as sparse.
    """
    kept: List[Dict[str, float]] = []
    for row in history:
        if row.get("is_refresh", 0.0) == 0.0:
            kept.append(row)
    return kept
472
+
473
+
474
def avg_sparse_metric(history: List[Dict[str, float]], key: str) -> float:
    """Mean of *key* over sparse-step rows; NaN when no sparse rows exist."""
    values = [row[key] for row in sparse_rows(history)]
    if not values:
        return float("nan")
    return sum(values) / len(values)
480
+
481
+
482
def print_last(label: str, history: List[Dict[str, float]]) -> None:
    """Print every field of the final history row, floats at 4 decimals."""
    print(f"\n{label}")
    final_row = history[-1]
    for key, value in final_row.items():
        if isinstance(value, float):
            print(f" {key:28s}: {value:.4f}")
        else:
            print(f" {key:28s}: {value}")
489
+
490
+
491
def print_sparse_summary(label: str, history: List[Dict[str, float]]) -> None:
    """Print sparse-step averages of key diagnostics plus final loss/accuracy."""
    last = history[-1]
    print(f"\n{label} sparse-step averages from logged checkpoints")
    metric_keys = (
        "cosine_true_vs_approx",
        "pred_explained_fraction",
        "top20_residual_mass",
        "top20_gradient_mass",
        "active_vs_oracle_mag",
        "active_vs_oracle_resid",
        "predmag_vs_oracle_mag",
        "active_stability",
    )
    for key in metric_keys:
        print(f" avg {key:24s}: {avg_sparse_metric(history, key):.4f}")
    print(f" final accuracy : {last['accuracy']:.4f}")
    print(f" final loss : {last['loss']:.4f}")
507
+
508
+
509
def print_checkpoints(label: str, history: List[Dict[str, float]]) -> None:
    """Print up to ~8 evenly spaced checkpoint rows from *history*.

    Rows carrying sparse diagnostics get an extended suffix with alignment
    and stability metrics.
    """
    print(f"\n{label} checkpoints:")
    stride = max(1, len(history) // 8)
    for row in history[::stride]:
        suffix = ""
        if "cosine_true_vs_approx" in row:
            suffix = (
                f" refresh={int(row['is_refresh'])}"
                f" active={row['active_fraction']:.2f}"
                f" cos={row['cosine_true_vs_approx']:.3f}"
                f" pred_expl={row['pred_explained_fraction']:.3f}"
                f" top20_grad={row['top20_gradient_mass']:.3f}"
                f" jacc_mag={row['active_vs_oracle_mag']:.3f}"
                f" stable={row['active_stability']:.3f}"
            )
        line = (
            f"step={row['step']:4d} "
            f"loss={row['loss']:.4f} "
            f"acc={row['accuracy']:.3f}"
            f"{suffix}"
        )
        print(line)
530
+
531
+
532
def main() -> None:
    """Run the dense baseline and every sparse policy, then print reports.

    All runs share one spiral dataset; each policy trains its own fresh
    TinyMLP via train_sparse_policy with identical hyperparameters.
    """
    X, y = make_spiral()

    baseline_model, baseline_hist = train_baseline(X, y)

    results = {}
    for policy in ["surprise", "magnitude", "predicted_magnitude", "random"]:
        _, hist = train_sparse_policy(
            X,
            y,
            policy=policy,
            active_fraction=0.2,
            refresh_interval=10,
            warmup_steps=100,
            grad_beta=0.95,
            mass_beta=0.95,
        )
        results[policy] = hist

    # Final-row dumps for the baseline and each policy.
    print_last("Baseline full Adam", baseline_hist)
    for policy, hist in results.items():
        print_last(f"{policy.title().replace('_', ' ')} Top-K simulated Adam", hist)

    # Checkpoint tables and sparse-step averages.
    print_checkpoints("Baseline", baseline_hist)
    for policy, hist in results.items():
        print_checkpoints(f"{policy.title().replace('_', ' ')} Top-K", hist)
        print_sparse_summary(f"{policy.title().replace('_', ' ')} Top-K", hist)

    print("\nHow to read this:")
    print(" predicted_magnitude => the practical policy to watch")
    print(" magnitude => oracle-ish upper bound using recent/current mass")
    print(" cos near 1.0 => approximate update points like the true gradient")
    print(" top20_grad high => raw gradient mass is heavy-tailed/concentrated")
    print(" jacc_mag high => selected blocks match current oracle top-k blocks")
    print(" stable high => active set is stable over time")
    print(" pred_mag close to mag => historical mass is enough to select useful blocks")


if __name__ == "__main__":
    main()
experiments/surprise_topk_gradient_prototype-v5.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Predicted-Magnitude Sparse Gradient Prototype, v5
3
+
4
+ Goal
5
+ ----
6
+ Test the practical version of the hypothesis more harshly:
7
+
8
+ The gradient/update signal is heavy-tailed, and the high-mass blocks are
9
+ predictable enough that we can update only those blocks most of the time.
10
+
11
+ v5 adds
12
+ -------
13
+ 1. inactive_mode="ema"
14
+ Inactive blocks receive the EMA-predicted gradient.
15
+
16
+ 2. inactive_mode="zero"
17
+ Inactive blocks receive zero gradient. This is the stricter and more
18
+ compute-relevant test.
19
+
20
+ 3. active_fraction sweep
21
+ Tests 20%, 10%, 5%, and 2% active blocks.
22
+
23
+ 4. focused policy comparison
24
+ - predicted_magnitude: practical policy, chooses active blocks from history
25
+ - magnitude: oracle-ish policy, chooses from recently observed mass
26
+ - random: control
27
+
28
+ Important caveat
29
+ ----------------
30
+ This still computes the full gradient every step. That is intentional. We are
31
+ measuring whether the selected active set would have preserved useful learning.
32
+ Actual speedup would require restricting/skipping backward computation for
33
+ inactive blocks with custom structured backward logic.
34
+
35
+ Run
36
+ ---
37
+     python3 surprise_topk_gradient_prototype-v5.py
38
+ """
39
+
40
+ from __future__ import annotations
41
+
42
+ import math
43
+ import random
44
+ from dataclasses import dataclass
45
+ from typing import Dict, List, Literal, Tuple
46
+
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+
51
+
52
+ SEED = 7
53
+ random.seed(SEED)
54
+ torch.manual_seed(SEED)
55
+
56
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
57
+ Policy = Literal["magnitude", "predicted_magnitude", "random"]
58
+ InactiveMode = Literal["ema", "zero"]
59
+
60
+
61
+ # -----------------------------
62
+ # Toy data: 2-class spiral
63
+ # -----------------------------
64
+
65
def make_spiral(n_per_class: int = 1024, noise: float = 0.12) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build a shuffled two-class spiral dataset, scaled by 3x, on DEVICE.

    Each class traces two full turns (4*pi) offset by pi, with Gaussian
    angular noise. Returns (X, Y) with X of shape (2*n_per_class, 2) and
    Y the integer class labels.
    """
    features: List[torch.Tensor] = []
    labels: List[torch.Tensor] = []

    for class_id in range(2):
        radius = torch.linspace(0.0, 1.0, n_per_class)
        angle = class_id * math.pi + radius * 4.0 * math.pi
        # Angular jitter keeps the two arms from being perfectly separable.
        angle = angle + torch.randn(n_per_class) * noise

        points = torch.stack([radius * torch.sin(angle), radius * torch.cos(angle)], dim=1)
        features.append(points)
        labels.append(torch.full((n_per_class,), class_id, dtype=torch.long))

    X = 3.0 * torch.cat(features, dim=0)
    Y = torch.cat(labels, dim=0)

    # Shuffle once so mini-batch sampling sees mixed classes.
    order = torch.randperm(X.shape[0])
    return X[order].to(DEVICE), Y[order].to(DEVICE)
86
+
87
+
88
+ # -----------------------------
89
+ # Model
90
+ # -----------------------------
91
+
92
class TinyMLP(nn.Module):
    """Small fixed-depth ReLU MLP mapping 2-D inputs to 2 class logits."""

    def __init__(self, width: int = 128):
        super().__init__()
        # Assemble layers in the exact same order as a literal Sequential
        # so parameter-initialization RNG draws are unchanged.
        stack: List[nn.Module] = [nn.Linear(2, width), nn.ReLU()]
        stack += [nn.Linear(width, width), nn.ReLU()]
        stack += [nn.Linear(width, width), nn.ReLU()]
        stack.append(nn.Linear(width, 2))
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits of shape (batch, 2)."""
        return self.net(x)
107
+
108
+
109
def linear_layers(model: nn.Module) -> List[nn.Linear]:
    """Collect every nn.Linear submodule of *model* in traversal order."""
    found: List[nn.Linear] = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            found.append(module)
    return found
111
+
112
+
113
+ # -----------------------------
114
+ # Sparse gradient machinery
115
+ # -----------------------------
116
+
117
@dataclass(frozen=True)
class BlockRef:
    """Immutable reference to one sparsity "block".

    A block is a single output row of the layer_index-th nn.Linear's weight
    matrix (plus, implicitly, that row's bias element).
    """

    # Index into the list returned by linear_layers(model).
    layer_index: int
    # Output-row index within that layer's weight matrix.
    row_index: int
121
+
122
+
123
class SparseGradientBuilder:
    """
    Builds approximate gradients after a full backward pass.

    Block = one output row of a Linear weight matrix, plus its bias element.

    Approx gradient:
        active blocks   -> true gradient
        inactive blocks -> either EMA-predicted gradient or zero gradient

    Selection policies:
        magnitude           -> largest recently observed gradient norm
        predicted_magnitude -> largest historical EMA gradient norm
        random              -> random active blocks
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy = "predicted_magnitude",
        inactive_mode: InactiveMode = "zero",
        grad_beta: float = 0.95,
        mass_beta: float = 0.95,
        active_fraction: float = 0.2,
        refresh_interval: int = 10,
        warmup_steps: int = 100,
        eps: float = 1e-12,
    ):
        # inactive_mode: what inactive blocks receive — "ema" (predicted
        # gradient) or "zero" (the stricter, compute-relevant test).
        # grad_beta / mass_beta: EMA decays for the gradient predictors and
        # the per-block gradient-norm tracker, respectively.
        self.model = model
        self.layers = linear_layers(model)
        self.policy = policy
        self.inactive_mode = inactive_mode
        self.grad_beta = grad_beta
        self.mass_beta = mass_beta
        self.active_fraction = active_fraction
        self.refresh_interval = refresh_interval
        self.warmup_steps = warmup_steps
        self.eps = eps

        # Enumerate every (layer, output-row) pair as one selectable block.
        self.blocks: List[BlockRef] = []
        for li, layer in enumerate(self.layers):
            for row in range(layer.weight.shape[0]):
                self.blocks.append(BlockRef(li, row))

        # EMA predictors of the full weight/bias gradients, keyed by layer index.
        self.pred_w: Dict[int, torch.Tensor] = {}
        self.pred_b: Dict[int, torch.Tensor] = {}

        for li, layer in enumerate(self.layers):
            self.pred_w[li] = torch.zeros_like(layer.weight.data)
            if layer.bias is not None:
                self.pred_b[li] = torch.zeros_like(layer.bias.data)

        # Per-block scalar trackers, initialized to ones so every block is a
        # plausible candidate before any gradients have been observed.
        n = len(self.blocks)
        self.current_gradient_mass = torch.ones(n, device=DEVICE)
        self.predicted_gradient_mass = torch.ones(n, device=DEVICE)
        # Previous active set, kept for the stability diagnostic.
        self.prev_active = torch.zeros(n, dtype=torch.bool, device=DEVICE)

    def _is_refresh_step(self, step: int) -> bool:
        """True when this step uses the full (dense) gradient for all blocks."""
        return step < self.warmup_steps or step % self.refresh_interval == 0

    def _choose_active_blocks(self, step: int) -> torch.Tensor:
        """Return a boolean mask over blocks selecting this step's active set.

        On refresh/warmup steps every block is active; otherwise the top-k
        blocks (k = active_fraction * n) under the configured policy.
        """
        n = len(self.blocks)

        if self._is_refresh_step(step):
            return torch.ones(n, dtype=torch.bool, device=DEVICE)

        k = max(1, int(self.active_fraction * n))
        active = torch.zeros(n, dtype=torch.bool, device=DEVICE)

        if self.policy == "magnitude":
            # Uses the previous step's observed gradient mass, not the current one.
            idx = torch.topk(self.current_gradient_mass, k=k).indices
        elif self.policy == "predicted_magnitude":
            # Practical policy: uses historical EMA only.
            idx = torch.topk(self.predicted_gradient_mass, k=k).indices
        elif self.policy == "random":
            idx = torch.randperm(n, device=DEVICE)[:k]
        else:
            raise ValueError(f"Unknown policy: {self.policy}")

        active[idx] = True
        return active

    @staticmethod
    def _topk_mask(values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask marking the top `fraction` of entries in *values*."""
        n = values.numel()
        k = max(1, int(fraction * n))
        mask = torch.zeros(n, dtype=torch.bool, device=values.device)
        mask[torch.topk(values, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard similarity |a & b| / |a | b| of two boolean masks (0.0 if both empty)."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    @torch.no_grad()
    def build_and_install_approx_grads(self, step: int) -> Dict[str, float]:
        """Replace .grad on every Linear with the block-sparse approximation.

        Must be called after loss.backward() and before optimizer.step().
        Active blocks keep their true gradient; inactive blocks receive the
        EMA prediction or zero, depending on inactive_mode. Also updates the
        EMA predictors and mass trackers, and returns logging diagnostics.
        """
        active = self._choose_active_blocks(step)
        is_refresh = self._is_refresh_step(step)

        # Flattened per-block vectors, accumulated for global diagnostics.
        true_parts = []
        approx_parts = []
        pred_parts = []

        # Fresh gradient buffers to be installed at the end.
        approx_w: Dict[int, torch.Tensor] = {}
        approx_b: Dict[int, torch.Tensor] = {}
        for li, layer in enumerate(self.layers):
            approx_w[li] = torch.zeros_like(layer.weight.grad)
            if layer.bias is not None:
                approx_b[li] = torch.zeros_like(layer.bias.grad)

        new_gradient_mass = torch.zeros_like(self.current_gradient_mass)

        for block_id, block in enumerate(self.blocks):
            li = block.layer_index
            row = block.row_index
            layer = self.layers[li]
            is_active = bool(active[block_id].item())

            # True and EMA-predicted gradients for this block's weight row.
            g_w = layer.weight.grad[row].detach().clone()
            p_w = self.pred_w[li][row].detach().clone()

            if layer.bias is not None:
                g_b = layer.bias.grad[row].detach().clone()
                p_b = self.pred_b[li][row].detach().clone()
            else:
                g_b = None
                p_b = None

            # Flatten weight row + bias element into one block vector.
            true_vec_items = [g_w.flatten()]
            pred_vec_items = [p_w.flatten()]
            if g_b is not None:
                true_vec_items.append(g_b.view(1))
                pred_vec_items.append(p_b.view(1))

            true_vec_block = torch.cat(true_vec_items)
            pred_vec_block = torch.cat(pred_vec_items)
            grad_norm = torch.norm(true_vec_block)
            new_gradient_mass[block_id] = grad_norm

            # Active blocks pass through the true gradient; inactive blocks
            # fall back per inactive_mode.
            if is_active:
                a_w = g_w
                a_b = g_b
            else:
                if self.inactive_mode == "ema":
                    a_w = p_w
                    a_b = p_b
                elif self.inactive_mode == "zero":
                    a_w = torch.zeros_like(g_w)
                    a_b = torch.zeros_like(g_b) if g_b is not None else None
                else:
                    raise ValueError(f"Unknown inactive_mode: {self.inactive_mode}")

            approx_w[li][row] = a_w
            if layer.bias is not None:
                approx_b[li][row] = a_b

            # Prototype choice: update predictors from true gradient because the
            # full gradient is available for measurement. A speedup version would
            # update predictors fully only on refresh steps and partially on active blocks.
            self.pred_w[li][row].mul_(self.grad_beta).add_(g_w, alpha=1.0 - self.grad_beta)
            if layer.bias is not None:
                self.pred_b[li][row].mul_(self.grad_beta).add_(g_b, alpha=1.0 - self.grad_beta)

            approx_vec_items = [a_w.flatten()]
            if a_b is not None:
                approx_vec_items.append(a_b.view(1))

            true_parts.append(true_vec_block)
            pred_parts.append(pred_vec_block)
            approx_parts.append(torch.cat(approx_vec_items))

        # Install approximate gradients for Adam.
        for li, layer in enumerate(self.layers):
            layer.weight.grad.copy_(approx_w[li])
            if layer.bias is not None:
                layer.bias.grad.copy_(approx_b[li])

        true_vec = torch.cat(true_parts)
        pred_vec = torch.cat(pred_parts)
        approx_vec = torch.cat(approx_parts)

        true_norm = torch.norm(true_vec)
        pred_error_norm = torch.norm(true_vec - pred_vec)

        # Global alignment / scale of the installed gradient vs the true one,
        # and the fraction of true-gradient energy explained by the predictor.
        cosine = F.cosine_similarity(true_vec, approx_vec, dim=0).item()
        approx_norm_ratio = float((torch.norm(approx_vec) / (true_norm + self.eps)).item())
        pred_explained = 1.0 - (
            pred_error_norm.pow(2) / (true_norm.pow(2) + self.eps)
        ).item()

        oracle_magnitude_mask = self._topk_mask(new_gradient_mass, self.active_fraction)
        predicted_magnitude_mask = self._topk_mask(self.predicted_gradient_mass, self.active_fraction)

        active_vs_oracle_mag = self._jaccard(active, oracle_magnitude_mask)
        predmag_vs_oracle_mag = self._jaccard(predicted_magnitude_mask, oracle_magnitude_mask)
        active_stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        # Concentration diagnostic: share of total mass held by the top 20%.
        k20 = max(1, int(0.2 * len(self.blocks)))
        sorted_gradient = torch.sort(new_gradient_mass.detach(), descending=True).values
        top20_gradient_mass = (sorted_gradient[:k20].sum() / (sorted_gradient.sum() + self.eps)).item()

        # Update mass trackers AFTER diagnostics, so predicted_magnitude really
        # uses only history at selection time.
        self.current_gradient_mass = new_gradient_mass
        self.predicted_gradient_mass.mul_(self.mass_beta).add_(new_gradient_mass, alpha=1.0 - self.mass_beta)

        return {
            "is_refresh": float(is_refresh),
            "active_fraction": float(active.float().mean().item()),
            "cosine_true_vs_approx": cosine,
            "approx_norm_ratio": approx_norm_ratio,
            "pred_explained_fraction": pred_explained,
            "top20_gradient_mass": top20_gradient_mass,
            "active_vs_oracle_mag": active_vs_oracle_mag,
            "predmag_vs_oracle_mag": predmag_vs_oracle_mag,
            "active_stability": active_stability,
            "true_grad_norm": float(true_norm.item()),
            "pred_error_norm": float(pred_error_norm.item()),
        }
346
+
347
+
348
+ # -----------------------------
349
+ # Metrics and training
350
+ # -----------------------------
351
+
352
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of rows of X whose argmax class prediction equals y."""
    model.eval()
    with torch.no_grad():
        predicted = model(X).argmax(dim=1)
    hits = predicted == y
    return hits.float().mean().item()
357
+
358
+
359
def train_baseline(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a fresh TinyMLP with dense Adam; return (model, logged history).

    History rows carry step, batch loss, and full-dataset accuracy, logged
    every 97 steps (a prime stride, to avoid aliasing) plus the final step.
    """
    model = TinyMLP().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        # Sample a random mini-batch with replacement.
        idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        xb, yb = X[idx], y[idx]

        loss = F.cross_entropy(model(xb), yb)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        if step % 97 == 0 or step == steps - 1:
            history.append({
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
            })

    return model, history
389
+
390
+
391
def train_sparse_policy(
    X: torch.Tensor,
    y: torch.Tensor,
    policy: Policy,
    inactive_mode: InactiveMode,
    active_fraction: float,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
    refresh_interval: int = 10,
    warmup_steps: int = 100,
    grad_beta: float = 0.95,
    mass_beta: float = 0.95,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a fresh TinyMLP with Adam on block-sparse approximate gradients.

    Identical to train_baseline except that, after each full backward pass,
    SparseGradientBuilder overwrites .grad with the block-sparse
    approximation (inactive blocks get EMA or zero per inactive_mode) before
    the optimizer step. History rows also include builder diagnostics.
    """
    model = TinyMLP().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    builder = SparseGradientBuilder(
        model,
        policy=policy,
        inactive_mode=inactive_mode,
        grad_beta=grad_beta,
        mass_beta=mass_beta,
        active_fraction=active_fraction,
        refresh_interval=refresh_interval,
        warmup_steps=warmup_steps,
    )

    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        xb, yb = X[idx], y[idx]

        loss = F.cross_entropy(model(xb), yb)

        opt.zero_grad(set_to_none=True)
        loss.backward()

        # Full gradient is computed above; the builder then swaps in the
        # block-sparse approximation and records diagnostics.
        diagnostics = builder.build_and_install_approx_grads(step)
        opt.step()

        # Same logging cadence as train_baseline (every 97 steps + last step).
        if step % 97 == 0 or step == steps - 1:
            history.append({
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
                **diagnostics,
            })

    return model, history
442
+
443
+
444
def sparse_rows(history: List[Dict[str, float]]) -> List[Dict[str, float]]:
    """Keep only rows logged on sparse (non-refresh) steps.

    Rows without an "is_refresh" key are treated as sparse.
    """
    kept: List[Dict[str, float]] = []
    for row in history:
        if row.get("is_refresh", 0.0) == 0.0:
            kept.append(row)
    return kept
446
+
447
+
448
def avg_sparse_metric(history: List[Dict[str, float]], key: str) -> float:
    """Mean of `key` over sparse (non-refresh) rows; NaN when there are none."""
    values = [row[key] for row in history if row.get("is_refresh", 0.0) == 0.0]
    if not values:
        return float("nan")
    return sum(values) / len(values)
454
+
455
+
456
def summarize_sparse_run(
    policy: Policy,
    inactive_mode: InactiveMode,
    active_fraction: float,
    history: List[Dict[str, float]],
) -> Dict[str, float | str]:
    """Collapse one sparse run into a single row for the summary table."""
    final = history[-1]
    summary: Dict[str, float | str] = {
        "policy": policy,
        "inactive_mode": inactive_mode,
        "active_fraction": active_fraction,
        "final_accuracy": final["accuracy"],
        "final_loss": final["loss"],
    }
    # Averages are taken over sparse (non-refresh) steps only.
    for out_key, metric_key in [
        ("avg_cosine", "cosine_true_vs_approx"),
        ("avg_norm_ratio", "approx_norm_ratio"),
        ("avg_top20_grad_mass", "top20_gradient_mass"),
        ("avg_jacc_mag", "active_vs_oracle_mag"),
        ("avg_predmag_jacc", "predmag_vs_oracle_mag"),
        ("avg_stability", "active_stability"),
    ]:
        summary[out_key] = avg_sparse_metric(history, metric_key)
    return summary
476
+
477
+
478
def print_baseline(history: List[Dict[str, float]]) -> None:
    """Print every field of the baseline run's final logged row."""
    print("\nBaseline full Adam")
    final = history[-1]
    for key, value in final.items():
        if isinstance(value, float):
            print(f" {key:24s}: {value:.4f}")
        else:
            print(f" {key:24s}: {value}")
485
+
486
+
487
+ def print_summary_table(rows: List[Dict[str, float | str]]) -> None:
488
+ print("\nSparse run summary")
489
+ header = (
490
+ f"{'policy':>20s} {'mode':>6s} {'active':>7s} "
491
+ f"{'acc':>7s} {'loss':>9s} {'cos':>7s} {'norm':>7s} "
492
+ f"{'top20':>7s} {'jacc':>7s} {'stable':>7s}"
493
+ )
494
+ print(header)
495
+ print("-" * len(header))
496
+
497
+ for row in rows:
498
+ print(
499
+ f"{str(row['policy']):>20s} "
500
+ f"{str(row['inactive_mode']):>6s} "
501
+ f"{float(row['active_fraction']):7.2f} "
502
+ f"{float(row['final_accuracy']):7.4f} "
503
+ f"{float(row['final_loss']):9.4f} "
504
+ f"{float(row['avg_cosine']):7.3f} "
505
+ f"{float(row['avg_norm_ratio']):7.3f} "
506
+ f"{float(row['avg_top20_grad_mass']):7.3f} "
507
+ f"{float(row['avg_jacc_mag']):7.3f} "
508
+ f"{float(row['avg_stability']):7.3f}"
509
+ )
510
+
511
+
512
def main() -> None:
    """Train the dense baseline, sweep sparse policies, and print a summary."""
    X, y = make_spiral()

    baseline_model, baseline_hist = train_baseline(X, y)
    print_baseline(baseline_hist)

    active_fractions = [0.20, 0.10, 0.05, 0.02]

    # Keep this matrix focused; 4 fractions x 4 configurations = 16 sparse runs.
    experiment_plan: List[Tuple[Policy, InactiveMode, float]] = []
    for fraction in active_fractions:
        for plan_policy, plan_mode in [
            ("predicted_magnitude", "ema"),
            ("predicted_magnitude", "zero"),
            ("magnitude", "zero"),
            ("random", "zero"),
        ]:
            experiment_plan.append((plan_policy, plan_mode, fraction))

    summary_rows: List[Dict[str, float | str]] = []

    for policy, inactive_mode, active_fraction in experiment_plan:
        print(
            f"\nRunning policy={policy}, inactive_mode={inactive_mode}, "
            f"active_fraction={active_fraction:.2f}"
        )
        _, hist = train_sparse_policy(
            X,
            y,
            policy=policy,
            inactive_mode=inactive_mode,
            active_fraction=active_fraction,
            refresh_interval=10,
            warmup_steps=100,
            grad_beta=0.95,
            mass_beta=0.95,
        )
        summary_rows.append(summarize_sparse_run(policy, inactive_mode, active_fraction, hist))

    print_summary_table(summary_rows)

    print("\nHow to read this:")
    print(" predicted_magnitude + zero is the main practical test.")
    print(" magnitude + zero is an oracle-ish upper bound using recent observed mass.")
    print(" random + zero is the control.")
    print(" acc close to baseline means sparse updates preserved learning.")
    print(" cos near 1.0 means sparse update direction matches full gradient direction.")
    print(" norm much below 1.0 means the sparse update is much smaller than full gradient.")
    print(" top20 near 0.7+ means gradient mass is concentrated/heavy-tailed.")
    print(" jacc above random means active-set prediction finds the true important blocks.")


if __name__ == "__main__":
    main()
experiments/surprise_topk_gradient_prototype.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Surprise Top-K Gradient Prototype
3
+
4
+ This is a deliberately small, readable PyTorch experiment for testing the idea:
5
+
6
+ gradient_t ≈ predicted_gradient_t + sparse_surprising_residual_t
7
+
8
+ The code compares:
9
+ 1. Baseline full SGD
10
+ 2. SurpriseTopK training, where only high-surprise parameter blocks use the true
11
+ gradient on cheap steps; low-surprise blocks use a predicted/stale gradient.
12
+
13
+ Important caveat:
14
+ This prototype still computes full gradients on every step so we can evaluate the
15
+ approximation honestly. It simulates reduced backward/update entropy; it does not
16
+ yet provide real wall-clock acceleration. Real acceleration would require structured
17
+ partial backward passes, custom kernels, or graph-level masking.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import math
23
+ import random
24
+ from dataclasses import dataclass
25
+ from typing import Dict, List, Tuple
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+
31
+
32
+ # -----------------------------
33
+ # Reproducibility
34
+ # -----------------------------
35
+
36
# Seed both Python's and torch's RNGs so runs are repeatable.
SEED = 7
random.seed(SEED)
torch.manual_seed(SEED)

# All tensors and models in this script are created on DEVICE.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41
+
42
+
43
+ # -----------------------------
44
+ # Toy data: 2-class spiral
45
+ # -----------------------------
46
+
47
def make_spiral(n_per_class: int = 512, noise: float = 0.2) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create a small nonlinear classification problem without external datasets.

    Two interleaved spiral arms (one per class), with Gaussian angular noise.
    Returns shuffled features (N, 2) and integer labels (N,), both on DEVICE.
    """
    feature_chunks = []
    label_chunks = []

    for class_id in range(2):
        radius = torch.linspace(0.0, 1.0, n_per_class)
        angle = class_id * math.pi + radius * 4.0 * math.pi
        angle = angle + torch.randn(n_per_class) * noise

        features = torch.stack([radius * torch.sin(angle), radius * torch.cos(angle)], dim=1)
        labels = torch.full((n_per_class,), class_id, dtype=torch.long)

        feature_chunks.append(features)
        label_chunks.append(labels)

    X = torch.cat(feature_chunks, dim=0)
    Y = torch.cat(label_chunks, dim=0)

    # Shuffle so the two classes are interleaved before batching.
    perm = torch.randperm(X.shape[0])
    return X[perm].to(DEVICE), Y[perm].to(DEVICE)
68
+
69
+
70
+ # -----------------------------
71
+ # Model
72
+ # -----------------------------
73
+
74
class TinyMLP(nn.Module):
    """Three-layer tanh MLP mapping 2-D points to two class logits."""

    def __init__(self, width: int = 64):
        super().__init__()
        layers = [
            nn.Linear(2, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
87
+
88
+
89
def linear_layers(model: nn.Module) -> List[nn.Linear]:
    """All nn.Linear submodules of `model`, in module-traversal order."""
    found: List[nn.Linear] = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            found.append(module)
    return found
91
+
92
+
93
+ # -----------------------------
94
+ # Block bookkeeping
95
+ # -----------------------------
96
+
97
@dataclass
class BlockRef:
    """Identifies one sparsity block: a single output row of a Linear layer.

    When the layer has a bias, that row's bias element belongs to the same block.
    """
    layer_index: int
    row_index: int
102
+
103
+
104
class SurpriseTopKUpdater:
    """
    Applies a predicted-gradient + top-surprise true-gradient update.

    Unit of sparsity:
        One output row of each Linear layer.

    Why rows?
        Row blocks correspond roughly to neurons/features, and structured blocks are
        much closer to real hardware speedup than individual unstructured weights.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 0.05,
        beta: float = 0.9,
        active_fraction: float = 0.2,
        refresh_interval: int = 10,
        use_error_feedback: bool = True,
        eps: float = 1e-12,
    ):
        # lr: SGD-style step size; beta: EMA factor for the gradient predictor;
        # active_fraction: fraction of blocks updated exactly on cheap steps;
        # refresh_interval: every N-th step updates all blocks exactly;
        # eps: numerical guard for the norm ratios below.
        self.model = model
        self.layers = linear_layers(model)
        self.lr = lr
        self.beta = beta
        self.active_fraction = active_fraction
        self.refresh_interval = refresh_interval
        self.use_error_feedback = use_error_feedback
        self.eps = eps

        # One BlockRef per output row across all Linear layers.
        self.blocks: List[BlockRef] = []
        for li, layer in enumerate(self.layers):
            for row in range(layer.weight.shape[0]):
                self.blocks.append(BlockRef(li, row))

        # Predicted gradients, same shape as parameters.
        self.pred_w: Dict[int, torch.Tensor] = {}
        self.pred_b: Dict[int, torch.Tensor] = {}

        # Error-feedback buffers accumulate information we did not apply.
        self.err_w: Dict[int, torch.Tensor] = {}
        self.err_b: Dict[int, torch.Tensor] = {}

        # Surprise scores per block. Higher means more worth computing/updating exactly.
        self.scores = torch.ones(len(self.blocks), device=DEVICE)

        for li, layer in enumerate(self.layers):
            self.pred_w[li] = torch.zeros_like(layer.weight.data)
            self.err_w[li] = torch.zeros_like(layer.weight.data)
            if layer.bias is not None:
                self.pred_b[li] = torch.zeros_like(layer.bias.data)
                self.err_b[li] = torch.zeros_like(layer.bias.data)

    def _block_grad_vector(self, li: int, row: int) -> torch.Tensor:
        """Flattened observed gradient for block (li, row): weight row + bias element."""
        layer = self.layers[li]
        parts = [layer.weight.grad[row].flatten()]
        if layer.bias is not None:
            parts.append(layer.bias.grad[row].view(1))
        return torch.cat(parts)

    def _block_pred_vector(self, li: int, row: int) -> torch.Tensor:
        """Flattened predicted gradient for block (li, row), same layout as the grad vector."""
        layer = self.layers[li]
        parts = [self.pred_w[li][row].flatten()]
        if layer.bias is not None:
            parts.append(self.pred_b[li][row].view(1))
        return torch.cat(parts)

    def _choose_active_blocks(self, step: int) -> torch.Tensor:
        """Return boolean mask over blocks."""
        n_blocks = len(self.blocks)

        # Full refresh: observe/update everything.
        if step % self.refresh_interval == 0:
            return torch.ones(n_blocks, dtype=torch.bool, device=DEVICE)

        # Otherwise: only the top-k highest-surprise blocks are active.
        k = max(1, int(self.active_fraction * n_blocks))
        active = torch.zeros(n_blocks, dtype=torch.bool, device=DEVICE)
        top_idx = torch.topk(self.scores, k=k).indices
        active[top_idx] = True
        return active

    @torch.no_grad()
    def step(self, step: int) -> Dict[str, float]:
        """
        Apply one optimizer step after loss.backward().

        Returns diagnostics comparing the approximate update to the true gradient.
        Also clears all parameter .grad tensors before returning.
        """
        active = self._choose_active_blocks(step)

        # Per-block flattened vectors, concatenated at the end for diagnostics.
        true_flat = []
        applied_flat = []
        pred_flat = []

        active_count = int(active.sum().item())
        total_count = len(self.blocks)

        for block_id, block in enumerate(self.blocks):
            li = block.layer_index
            row = block.row_index
            layer = self.layers[li]
            is_active = bool(active[block_id].item())

            g_w = layer.weight.grad[row].clone()
            g_b = layer.bias.grad[row].clone() if layer.bias is not None else None

            # Error feedback makes skipped/prediction error come back later instead
            # of disappearing forever.
            if self.use_error_feedback:
                g_w_eff = g_w + self.err_w[li][row]
                g_b_eff = g_b + self.err_b[li][row] if g_b is not None else None
            else:
                g_w_eff = g_w
                g_b_eff = g_b

            # Snapshot the predictor before it is (possibly) EMA-updated below.
            p_w = self.pred_w[li][row].clone()
            p_b = self.pred_b[li][row].clone() if layer.bias is not None else None

            if is_active:
                # Use the exact observed gradient for high-surprise blocks.
                applied_w = g_w_eff
                applied_b = g_b_eff

                # Update predictor only where we pretend we actually observed the gradient.
                self.pred_w[li][row].mul_(self.beta).add_(g_w, alpha=1.0 - self.beta)
                if layer.bias is not None:
                    self.pred_b[li][row].mul_(self.beta).add_(g_b, alpha=1.0 - self.beta)

                # Update surprise score: how wrong was the current predictor?
                # NOTE(review): _block_pred_vector reads the predictor AFTER the EMA
                # update above, so the score measures the post-update residual rather
                # than the pre-update prediction error — confirm this is intended.
                pred_vec = self._block_pred_vector(li, row)
                grad_vec = self._block_grad_vector(li, row)
                residual_norm = torch.norm(grad_vec - pred_vec)
                grad_norm = torch.norm(grad_vec)
                self.scores[block_id] = residual_norm / (grad_norm + self.eps)
            else:
                # Use the predicted/stale gradient for low-surprise blocks.
                applied_w = p_w
                applied_b = p_b

            # Update error-feedback buffers.
            if self.use_error_feedback:
                self.err_w[li][row] = g_w_eff - applied_w
                if layer.bias is not None:
                    self.err_b[li][row] = g_b_eff - applied_b

            # Apply update.
            layer.weight.data[row].add_(applied_w, alpha=-self.lr)
            if layer.bias is not None:
                layer.bias.data[row].add_(applied_b, alpha=-self.lr)

            # Diagnostics.
            true_parts = [g_w.flatten()]
            app_parts = [applied_w.flatten()]
            pred_parts = [p_w.flatten()]

            if layer.bias is not None:
                true_parts.append(g_b.view(1))
                app_parts.append(applied_b.view(1))
                pred_parts.append(p_b.view(1))

            true_flat.append(torch.cat(true_parts))
            applied_flat.append(torch.cat(app_parts))
            pred_flat.append(torch.cat(pred_parts))

        true_vec = torch.cat(true_flat)
        applied_vec = torch.cat(applied_flat)
        pred_vec = torch.cat(pred_flat)

        cosine = F.cosine_similarity(true_vec, applied_vec, dim=0).item()
        # Fraction of true-gradient energy explained by the (pre-update) predictor.
        pred_explained = 1.0 - (
            torch.norm(true_vec - pred_vec).pow(2) / (torch.norm(true_vec).pow(2) + self.eps)
        ).item()

        # Heavy-tail diagnostic: how much surprise mass lives in the top 20% blocks?
        k20 = max(1, int(0.2 * total_count))
        sorted_scores = torch.sort(self.scores.detach(), descending=True).values
        top20_mass = (sorted_scores[:k20].sum() / (sorted_scores.sum() + self.eps)).item()

        # Clear gradients manually.
        for p in self.model.parameters():
            p.grad = None

        return {
            "active_fraction": active_count / total_count,
            "cosine_true_vs_applied": cosine,
            "pred_explained_fraction": pred_explained,
            "top20_surprise_mass": top20_mass,
        }
293
+
294
+
295
+ # -----------------------------
296
+ # Training loops
297
+ # -----------------------------
298
+
299
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    """Classification accuracy of `model` over the whole dataset (X, y)."""
    model.eval()
    with torch.no_grad():
        hard_labels = model(X).argmax(dim=1)
        hits = (hard_labels == y).float()
        return hits.mean().item()
304
+
305
+
306
def train_baseline(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 600,
    batch_size: int = 128,
    lr: float = 0.05,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a TinyMLP with plain SGD, logging loss/accuracy every 25 steps."""
    model = TinyMLP().to(DEVICE)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    history = []

    for step in range(steps):
        sample_idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        batch_x, batch_y = X[sample_idx], y[sample_idx]

        logits = model(batch_x)
        loss = F.cross_entropy(logits, batch_y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 25 == 0 or step == steps - 1:
            history.append({
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
            })

    return model, history
336
+
337
+
338
def train_surprise_topk(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 600,
    batch_size: int = 128,
    lr: float = 0.05,
    active_fraction: float = 0.2,
    refresh_interval: int = 10,
    beta: float = 0.9,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a TinyMLP where SurpriseTopKUpdater applies the parameter updates.

    The updater consumes .grad after each backward pass, applies the mixed
    exact/predicted update with error feedback, and clears the gradients.
    """
    model = TinyMLP().to(DEVICE)
    updater = SurpriseTopKUpdater(
        model,
        lr=lr,
        beta=beta,
        active_fraction=active_fraction,
        refresh_interval=refresh_interval,
        use_error_feedback=True,
    )

    history = []

    for step in range(steps):
        sample_idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        batch_x, batch_y = X[sample_idx], y[sample_idx]

        logits = model(batch_x)
        loss = F.cross_entropy(logits, batch_y)
        loss.backward()

        diagnostics = updater.step(step)

        if step % 25 == 0 or step == steps - 1:
            snapshot = {
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
                **diagnostics,
            }
            history.append(snapshot)

    return model, history
380
+
381
+
382
+ # -----------------------------
383
+ # Main experiment
384
+ # -----------------------------
385
+
386
def print_last(label: str, history: List[Dict[str, float]]) -> None:
    """Print the final logged row under `label`; floats formatted to 4 decimals."""
    print(f"\n{label}")
    final = history[-1]
    for key, value in final.items():
        if isinstance(value, float):
            print(f" {key:28s}: {value:.4f}")
        else:
            print(f" {key:28s}: {value}")
394
+
395
+
396
def main() -> None:
    """Run baseline SGD and Surprise Top-K side by side and print diagnostics."""
    X, y = make_spiral(n_per_class=768, noise=0.18)

    baseline_model, baseline_hist = train_baseline(X, y)

    surprise_model, surprise_hist = train_surprise_topk(
        X,
        y,
        active_fraction=0.2,
        refresh_interval=10,
        beta=0.9,
    )

    print_last("Baseline full SGD", baseline_hist)
    print_last("Surprise Top-K simulated training", surprise_hist)

    print("\nA few Surprise Top-K checkpoints:")
    stride = max(1, len(surprise_hist) // 8)
    for row in surprise_hist[::stride]:
        print(
            f"step={row['step']:4d} "
            f"loss={row['loss']:.4f} "
            f"acc={row['accuracy']:.3f} "
            f"active={row['active_fraction']:.2f} "
            f"cos={row['cosine_true_vs_applied']:.3f} "
            f"pred_expl={row['pred_explained_fraction']:.3f} "
            f"top20_mass={row['top20_surprise_mass']:.3f}"
        )


if __name__ == "__main__":
    main()
triton_sparse.py → experiments/triton_sparse.py RENAMED
File without changes
triton_v2.py → experiments/triton_v2.py RENAMED
File without changes
experiments/uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
paper/main.tex ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ \documentclass[11pt]{article}
2
+
3
+ \usepackage[margin=1in]{geometry}
4
+ \usepackage{amsmath,amssymb}
5
+ \usepackage{graphicx}
6
+ \usepackage{microtype}
7
+ \usepackage{booktabs}
8
+ \usepackage{tabularx}
9
+ \usepackage{hyperref}
10
+
11
+ \hypersetup{
12
+ colorlinks=true,
13
+ linkcolor=blue,
14
+ citecolor=blue,
15
+ urlcolor=blue,
16
+ }
17
+
18
+ \title{%
19
+ \textbf{Zero-Copy Sparse Backpropagation}: \\[0.3em]
20
+ \large Temporal Gradient Tracking for
21
+ Faster, Regularized LLM Training
22
+ }
23
+
24
+ \author{
25
+ Daniel Owen van Dommelen\\
26
+ \textit{Independent Research - WORKING DRAFT}\\
27
+ \texttt{theapemachine@gmail.com}
28
+ }
29
+
30
+ \date{\today}
31
+
32
+ \begin{document}
33
+ \maketitle
34
+
35
+ \begin{abstract}
36
+ We describe \emph{Predictive Chunked Sparsity}: fixed top-$k$ row-chunks for
37
+ sparse $dW$, selected by an EMA of past chunk-gradient norms, with contiguous
38
+ slices for PyTorch-style GEMMs. On \textbf{Apple MPS} (full 6-layer runs:
39
+ $B{=}8$, $T{=}256$, chunk 64, $10\%$ active, 2000 steps), sparse training is
40
+ \textbf{slower} at $d_{\text{model}}{=}512$ ($\sim$1.22$\times$ higher ms/step than
41
+ dense for both $G_X$ modes) but \textbf{faster} at $d{=}2048$ ($\sim$1.18$\times$
42
+ and $\sim$1.221$\times$ speedup for full-$G_X$ and sparse-$G_X$ respectively,
43
+ with validation loss reported in Table~\ref{tab:mps-e2e}).
44
+
45
+ On \textbf{NVIDIA T4}, an isolated single-FFN timing harness (100 iters, fp32,
46
+ same $B,T$, chunk 64, $10\%$ active) shows full-$G_X$ totals from
47
+ 1.02$\times$ at $d{=}256$ to 1.35$\times$ at $d{=}2048$
48
+ (Table~\ref{tab:t4-ffn-micro}). A fused \textbf{Triton} backward passes numeric
49
+ checks (Table~\ref{tab:triton-correctness}); isolated backward on T4 improves
50
+ over dense for $d\ge 512$ but can trail PyTorch at $d{=}256$
51
+ (Table~\ref{tab:triton-backward}). Short \textbf{T4 end-to-end} training (100
52
+ steps) shows modest PyLoop gains at $d{=}512$/1024 and Triton autotune/noise
53
+ hurting at small scale (Table~\ref{tab:t4-e2e}). EMA--oracle chunk overlap on one
54
+ seed is in Table~\ref{tab:ema-overlap}; multi-seed long runs were pending at
55
+ draft time.
56
+ \end{abstract}
57
+
58
+ \section{Introduction}
59
+ Training transformers is dominated by dense matmuls. Some work reports
60
+ heavy-tailed gradient coordinates; whether that yields wall-clock savings depends
61
+ on implementation and hardware. Dynamic sparsity often hits irregular memory
62
+ access and, for variable masks, possible host--device coordination for shapes.
63
+ We use \emph{fixed-cardinality} chunk masks, an EMA scorer, cosine annealing,
64
+ and strided views (and optionally Triton) so active tiles map to dense GEMMs.
65
+ Contributions are \textbf{(1)} the algorithmic recipe, \textbf{(2)}
66
+ reproducible tables for MPS full training, T4 microbenchmarks, Triton
67
+ correctness and speed, short T4 E2E, chunk-size timing, and \textbf{(3)} honest
68
+ limits: speedups are width-, backend-, and workload-dependent.
69
+
70
+ \section{Methodology: Predictive Chunked Sparsity}
71
+ Linear $W\in\mathbb{R}^{O\times I}$ is split into $N$ row chunks of size $C$.
72
+ Binary mask $A\in\{0,1\}^N$ picks active chunks; inactive $dW$ is zeroed into
73
+ the optimizer. EMA on observed chunk norms $M_c^{(t)}=\beta
74
+ M_c^{(t-1)}+(1-\beta)\|G_{W_c}\|_2$ (active); $M_c^{(t)}=\gamma M_c^{(t-1)}$
75
+ (inactive). Top-$k$ chunks from $M^{(t-1)}$ fix $A$ at step $t$. Cosine schedule
76
+ $S(t)$ warms up fully dense then anneals toward $S_{\text{target}}$. With AdamW,
77
+ $g{=}0$ on inactive weights yields decaying moments (``phantom momentum'')---
78
+ standard Adam side effect, not a separate contribution.
79
+
80
+ \section{Systems}
81
+ Fixed $k$ avoids mask-derived index tensor sizes. Chunk rows are contiguous
82
+ slices (\texttt{gy\_flat[:, s:e]}). PyTorch normally implements that as a view;
83
+ exact behavior is version-dependent. A Python loop over active chunks issues
84
+ multiple kernel launches; Triton fusion targets that overhead (see
85
+ Table~\ref{tab:triton-backward}).
86
+
87
+ \section{Experiments and Results}
88
+ All numbers below are from recorded runs; GPU, hyperparameters, and seed are
89
+ stated per table. We do not claim universal ranking of backends.
90
+
91
+ \subsection{Full training on Apple MPS (author runs)}
92
+ Six layers, $B{=}8$, $T{=}256$, chunk\_size${=}64$, $10\%$ active chunks, 2000
93
+ optimization steps. Times are total wall for 2000 steps; ms/step derived.
94
+
95
+ \begin{table}[t]
96
+ \centering
97
+ \caption{MPS full training (single-seed author configuration per run).}
98
+ \label{tab:mps-e2e}
99
+ \begin{tabular}{l l r r r}
100
+ \toprule
101
+ $d_{\text{model}}$ & Run & Time (s) & ms/step & Val.\ loss \\
102
+ \midrule
103
+ 512 & \texttt{dense\_baseline} & 74.77 & 99.70 & 5.3142 \\
104
+ 512 & \texttt{sparse\_full\_dX} & 91.04 & 121.38 & 5.4141 \\
105
+ 512 & \texttt{sparse\_sparse\_dX} & 93.33 & 124.44 & 5.5467 \\
106
+ \midrule
107
+ 2048 & \texttt{dense\_baseline} & 1035.84 & 591.91 & 6.0264 \\
108
+ 2048 & \texttt{sparse\_full\_dX} & 875.51 & 500.29 & 5.9807 \\
109
+ 2048 & \texttt{sparse\_sparse\_dX} & 847.22 & 484.13 & 6.0231 \\
110
+ \bottomrule
111
+ \end{tabular}
112
+ \end{table}
113
+
114
+ At $d{=}512$, sparse ms/step is $\sim$1.22$\times$ (\texttt{sparse\_full\_dX})
115
+ and $\sim$1.25$\times$ (\texttt{sparse\_sparse\_dX}) vs.\ dense---\emph{slower}.
116
+ At $d{=}2048$, sparse is $\sim$1.18$\times$ and $\sim$1.22$\times$
117
+ \emph{faster}. Validation loss at $d{=}2048$ is best for
118
+ \texttt{sparse\_full\_dX} in this table; at $d{=}512$ dense is best.
119
+
120
+ \subsection{Isolated FFN layer microbenchmark (T4)}
121
+ One FFN block, $M{=}2048$, $B{=}8$, $T{=}256$, chunk\_size${=}64$, $10\%$ active,
122
+ fp32, 100 iterations. Components: forward, $dX$, $dW$ dense vs.\ sparse;
123
+ \emph{full\_$G_X$} total = sum with dense $dX$.
124
+
125
+ \begin{table}[t]
126
+ \centering
127
+ \caption{T4: per--FFN-layer times (ms). Spd. $=$ Tot.\ den.\,/{}Tot.\ sp.f.;
128
+ sparse total uses dense $dX$ (full\_dX).}
129
+ \label{tab:t4-ffn-micro}
130
+ \resizebox{\linewidth}{!}{%
131
+ \footnotesize
132
+ \begin{tabular}{r r r r r r r r r r}
133
+ \toprule
134
+ $d_{\text{model}}$ & FFN dim & Params & Fwd & $dX$ & $dW_{\mathrm{d}}$ &
135
+ $dW_{\mathrm{s}}$ & Tot.\ den. & Tot.\ sp.f. & Spd. \\
136
+ \midrule
137
+ 256 & 1024 & 0.3M & 0.27 & 0.21 & 0.27 & 0.26 & 0.75 & 0.74 & 1.02$\times$ \\
138
+ 384 & 1536 & 0.6M & 0.52 & 0.69 & 0.61 & 0.18 & 1.82 & 1.39 & 1.31$\times$ \\
139
+ 512 & 2048 & 1.0M & 1.00 & 1.01 & 0.97 & 0.26 & 2.99 & 2.28 & 1.31$\times$ \\
140
+ 768 & 3072 & 2.4M & 2.16 & 2.25 & 2.05 & 0.40 & 6.46 & 4.81 & 1.34$\times$ \\
141
+ 1024 & 4096 & 4.2M & 3.69 & 3.90 & 3.35 & 0.59 & 10.95 & 8.18 & 1.34$\times$ \\
142
+ 1536 & 6144 & 9.4M & 10.33 & 9.03 & 8.14 & 1.30 & 27.50 & 20.66 & 1.33$\times$ \\
143
+ 2048 & 8192 & 16.8M & 14.76 & 15.57 & 13.19 & 1.93 & 43.51 & 32.26 & 1.35$\times$ \\
144
+ \bottomrule
145
+ \end{tabular}%
146
+ }
147
+ \end{table}
148
+
149
+ If $dW_{\mathrm{dense}}$ were removed from the dense total, a simple
150
+ illustrative ratio (using the measured forward+$dX$ share) implies a ceiling
151
+ around $\sim$1.42--1.48$\times$ for this harness; crossover for net speedup vs.\
152
+ dense full-$G_X$ is near $d_{\text{model}}\approx 384$ in this table.
153
+
154
+ \subsection{Triton numeric checks (T4)}
155
+ Max absolute errors vs.\ reference (fp32 tolerances in experiment script); all
156
+ marked passing in the run log.
157
+
158
+ \begin{table}[t]
159
+ \centering
160
+ \caption{Triton backward vs.\ reference: max abs error.}
161
+ \label{tab:triton-correctness}
162
+ \begin{tabular}{r r r r r r r}
163
+ \toprule
164
+ $d_{\mathrm{in}}$ & $d_{\mathrm{out}}$ & ch. & $\max|dW|$ & $\max|db|$ & $\max|dX|$ & OK \\
165
+ \midrule
166
+ 512 & 2048 & 64 & 0.000320 & 0.000023 & 0.000042 & $\checkmark$ \\
167
+ 1024 & 4096 & 64 & 0.000443 & 0.000021 & 0.000092 & $\checkmark$ \\
168
+ 256 & 1024 & 32 & 0.000275 & 0.000038 & 0.000019 & $\checkmark$ \\
169
+ \bottomrule
170
+ \end{tabular}
171
+ \end{table}
172
+
173
+ \subsection{Isolated backward: Dense vs.\ PyLoop vs.\ Triton (T4)}
174
+ $M{=}2048$, chunk\_size${=}64$, $10\%$ active, full\_$G_X$ mode, 50 iterations
175
+ post-warmup. Times are full backward ms for the timed region (as recorded).
176
+
177
+ \begin{table}[t]
178
+ \centering
179
+ \caption{T4 isolated backward (ms). Triton/Dense $=$ dense/time\_triton.}
180
+ \label{tab:triton-backward}
181
+ \resizebox{\linewidth}{!}{%
182
+ \footnotesize
183
+ \begin{tabular}{r r r r r r r r}
184
+ \toprule
185
+ $d_{\text{model}}$ & FFN & Active ch. & Dense & PyLoop & Triton &
186
+ T/Dense & T/PyLoop \\
187
+ \midrule
188
+ 256 & 1024 & 1 & 0.39 & 0.40 & 0.46 & 0.85$\times$ & 0.88$\times$ \\
189
+ 512 & 2048 & 3 & 1.96 & 1.30 & 1.16 & 1.69$\times$ & 1.12$\times$ \\
190
+ 768 & 3072 & 4 & 4.29 & 2.52 & 2.51 & 1.70$\times$ & 1.00$\times$ \\
191
+ 1024 & 4096 & 6 & 7.29 & 4.37 & 4.30 & 1.70$\times$ & 1.02$\times$ \\
192
+ 1536 & 6144 & 9 & 17.32 & 10.04 & 9.78 & 1.77$\times$ & 1.03$\times$ \\
193
+ 2048 & 8192 & 12 & 29.14 & 17.20 & 16.89 & 1.73$\times$ & 1.02$\times$ \\
194
+ \bottomrule
195
+ \end{tabular}%
196
+ }
197
+ \end{table}
198
+
199
+ \noindent\textbf{Triton with both $dW$ and $dX$ sparse} (same harness family;
200
+ user-reported row):
201
+
202
+ \begin{table}[h]
203
+ \centering
204
+ \begin{tabular}{r r r r}
205
+ \toprule
206
+ $d_{\text{model}}$ & Dense (ms) & Triton\_all (ms) & Speedup \\
207
+ \midrule
208
+ 512 & 1.96 & 0.41 & 4.83$\times$ \\
209
+ 1024 & 7.06 & 1.07 & 6.58$\times$ \\
210
+ 2048 & 29.00 & 3.71 & 7.81$\times$ \\
211
+ \bottomrule
212
+ \end{tabular}
213
+ \end{table}
214
+
215
+ At $d{=}256$, Triton is slower than dense in Table~\ref{tab:triton-backward}
216
+ (0.85$\times$); at $d{=}512$, PyTorch single-kernel launches can still be hard
217
+ to beat for only three active chunks.
218
+
219
+ \subsection{End-to-end training on T4 (100 steps)}
220
+ Six layers, 8 heads, $B{=}8$, $T{=}256$, chunk\_size${=}64$, $10\%$ active,
221
+ seed${=}42$, AdamW lr$=$5e-4, full\_$G_X$. $d{=}2048$ did not fit 16GB T4.
222
+
223
+ \begin{table}[t]
224
+ \centering
225
+ \caption{T4 E2E (100 steps); ``vs Dense'' is dense/ms\_mode.}
226
+ \label{tab:t4-e2e}
227
+ \begin{tabular}{r l r r r}
228
+ \toprule
229
+ $d_{\text{model}}$ & Mode & ms/step & vs.\ Dense & Val.\ loss \\
230
+ \midrule
231
+ 512 & dense & 184.6 & 1.00$\times$ & 5.6954 \\
232
+ 512 & pyloop & 179.0 & 1.03$\times$ & 5.8683 \\
233
+ 512 & triton & 196.0 & 0.94$\times$ & 5.8683 \\
234
+ \midrule
235
+ 1024 & dense & 451.5 & 1.00$\times$ & 5.5300 \\
236
+ 1024 & pyloop & 435.6 & 1.04$\times$ & 5.4803 \\
237
+ 1024 & triton & 441.0 & 1.02$\times$ & 5.4800 \\
238
+ \bottomrule
239
+ \end{tabular}
240
+ \end{table}
241
+
242
+ Triton E2E at $d{=}512$ is slower than dense here; autotune and short-run
243
+ overhead dominate at small scale in the author's log.
244
+
245
+ \subsection{EMA vs.\ oracle chunk overlap (T4)}
246
+ $d{=}512$, 6 layers, chunk\_size${=}64$, $10\%$ active, 350 steps, seed${=}42$;
247
+ first check step ${=}250$ post-anneal schedule. Jaccard/Recall vs.\ dense-oracle
248
+ top-$k$ (as implemented in experiment).
249
+
250
+ \begin{table}[t]
251
+ \centering
252
+ \caption{Predictor overlap (single seed; multi-seed long runs were pending).}
253
+ \label{tab:ema-overlap}
254
+ \begin{tabular}{r r r}
255
+ \toprule
256
+ Step & Jaccard & Recall \\
257
+ \midrule
258
+ 250 & 0.6000 & 0.7500 \\
259
+ 275 & 0.6552 & 0.7917 \\
260
+ 300 & 0.7778 & 0.8750 \\
261
+ 325 & 0.6000 & 0.7500 \\
262
+ \bottomrule
263
+ \end{tabular}
264
+ \end{table}
265
+
266
+ \subsection{Chunk size vs.\ step time (T4, PyLoop)}
267
+ $d{=}512$, 6 layers, $10\%$ active, seed${=}42$, 50 training steps (warmup;
268
+ loss not converged---timing only).
269
+
270
+ \begin{table}[t]
271
+ \centering
272
+ \caption{ms/step vs.\ chunk size (PyLoop backend).}
273
+ \label{tab:chunk-size}
274
+ \begin{tabular}{r r}
275
+ \toprule
276
+ Chunk size & ms/step \\
277
+ \midrule
278
+ 16 & 601.4 \\
279
+ 32 & 453.0 \\
280
+ 64 & 321.5 \\
281
+ 128 & 251.3 \\
282
+ 256 & 219.8 \\
283
+ \bottomrule
284
+ \end{tabular}
285
+ \end{table}
286
+
287
+ Larger chunks $\Rightarrow$ fewer Python iterations per layer in this backend.
288
+
289
+ \subsection{Pending experiments (snapshot)}
290
+ At draft time, additional A10G jobs were in flight, e.g.\ internal IDs
291
+ \texttt{69f38371d70108f37ace1cae} (multi-baseline 2000-step suite),
292
+ \texttt{69f395b3d70108f37ace1cee} ($d$ scaling), and
293
+ \texttt{69f3af45d2c8bd8662bd419d} (E2E Triton including $d{=}2048$). Treat these
294
+ only as lab run pointers.
295
+
296
+ \section{Conclusion}
297
+ Chunked EMA sparsity is not uniformly faster: \textbf{MPS} shows a crossover in
298
+ $d_{\text{model}}$ between 512 and 2048 for full training;
299
+ \textbf{T4} microbenchmarks monotonically favor sparse full-$G_X$ totals from
300
+ $d{\approx}384$ upward to 1.35$\times$ at $d{=}2048$ in Table~\ref{tab:t4-ffn-micro},
301
+ while \textbf{T4 E2E} at 100 steps shows small PyLoop wins and Triton not yet
302
+ winning at $d{=}512$. Triton shows large factors when both $dW$ and $dX$ are
303
+ sparse in the isolated harness, subject to training-quality tradeoffs not fully
304
+ tabulated here. Future work: complete multi-seed tables and fused-kernel E2E at
305
+ large $d$.
306
+
307
+ \end{document}
sparse_transformer_v18_fast_knn.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sparse Transformer v18: Fast Chunked Sparse Backward + KNN Sensor Scheduler.
3
+
4
+ This plugs the v16/v17 KNN sensor scheduler into the real chunked sparse backward path.
5
+ It compares dense, EMA-topk sparse, KNN sparse, and random sparse in full_dX and sparse_dX modes.
6
+
7
+ Run:
8
+ python3 sparse_transformer_v18_fast_knn.py --device mps --benchmark_sync
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import argparse
14
+ import math
15
+ import random
16
+ import time
17
+ from typing import Dict, List, Literal, Optional, Tuple
18
+
19
+ import torch
20
+
21
+ torch.set_num_threads(1)
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ Scheduler = Literal["dense", "ema_topk", "knn_scheduler", "random"]
26
+ BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
27
+
28
+
29
def sync_device(device: str) -> None:
    """Block until all queued kernels on ``device`` have finished.

    No-op for CPU or for accelerators that are unavailable at runtime.
    """
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
34
+
35
+
36
def set_seed(seed: int) -> None:
    """Seed both Python's ``random`` module and torch's global RNG."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
39
+
40
+
41
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU-side ``torch.Generator`` seeded with ``seed``.

    ``Generator.manual_seed`` returns the generator itself, so the call
    can be chained.
    """
    return torch.Generator(device="cpu").manual_seed(seed)
45
+
46
+
47
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic corpus of newline-separated random sentences.

    Each sentence is 4-10 words sampled (with replacement) from a fixed
    vocabulary, terminated by a period.
    """
    rng = random.Random(seed)
    vocab = [
        "ada", "turing", "grace", "lovelace", "gradients", "tokens", "circuits",
        "features", "boldly", "strangely", "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sentences = []
    for _ in range(n_sentences):
        # randint is drawn first (it was the argument expression originally),
        # preserving the RNG call order and therefore the exact output.
        n_words = rng.randint(4, 10)
        sentences.append(" ".join(rng.choices(vocab, k=n_words)) + ".")
    return "\n".join(sentences)
56
+
57
+
58
class CharCorpus:
    """Character-level corpus with a fixed 90/10 train/validation split."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device
        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None):
        """Sample a batch of (input, next-char target) windows from one split."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        inputs = torch.stack([s_i for s_i in (source[s : s + self.block_size] for s in starts)])
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
75
+
76
+
77
class ChunkedMaskedLinear(torch.autograd.Function):
    """Linear op with dense forward and chunk-sparse backward.

    Only the output chunks listed in ``active_chunks`` receive weight/bias
    gradients. When ``sparse_dx`` is True the input gradient is also limited
    to contributions flowing through those active chunks.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, active_chunks, chunk_size: int, sparse_dx: bool):
        # Forward is an ordinary dense linear; sparsity only changes backward.
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        ctx.chunk_size = int(chunk_size)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y):
        x, weight, active_chunks = ctx.saved_tensors
        cs = ctx.chunk_size
        x2d = x.reshape(-1, x.shape[-1])
        gy2d = grad_y.reshape(-1, grad_y.shape[-1])
        grad_w = torch.zeros_like(weight)
        if ctx.has_bias:
            grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)
        else:
            grad_b = None
        # Full dX costs one dense matmul; sparse dX is built up chunk by chunk.
        if ctx.sparse_dx:
            grad_x2d = torch.zeros_like(x2d)
        else:
            grad_x2d = gy2d @ weight
        for chunk in active_chunks.tolist():
            lo = int(chunk) * cs
            hi = lo + cs
            gy_rows = gy2d[:, lo:hi]
            grad_w[lo:hi, :] = gy_rows.transpose(0, 1) @ x2d
            if grad_b is not None:
                grad_b[lo:hi] = gy_rows.sum(dim=0)
            if ctx.sparse_dx:
                grad_x2d += gy_rows @ weight[lo:hi, :]
        return grad_x2d.reshape(x.shape), grad_w, grad_b, None, None, None
107
+
108
+
109
class SparseLinear(nn.Linear):
    """``nn.Linear`` whose backward can be restricted to active output chunks.

    With ``sparse_enabled`` False or no chunk mask installed this behaves
    exactly like a plain dense linear layer.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False  # toggled per-step by the training loop
        self.sparse_dx = False       # additionally sparsify the input gradient
        self.active_chunks: Optional[torch.Tensor] = None  # local chunk ids
        self.chunk_size = 64

    def forward(self, x):
        use_sparse = self.sparse_enabled and self.active_chunks is not None
        if use_sparse:
            return ChunkedMaskedLinear.apply(
                x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx
            )
        return F.linear(x, self.weight, self.bias)
121
+
122
+
123
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention using SparseLinear projections.

    Both the fused qkv projection and the output projection are SparseLinear,
    so they participate in the chunk-sparse backward pass.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask, shaped to broadcast over (batch, head).
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.c_attn(x)  # [B, T, 3C]
        q, k, v = qkv.split(C, dim=2)
        # Reshape to [B, n_head, T, head_dim] for per-head attention.
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Future positions are masked to -inf before the softmax.
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
148
+
149
+
150
class FeedForward(nn.Module):
    """GPT-style MLP block (4x expansion, GELU) built on SparseLinear layers."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # expand -> GELU -> project -> dropout
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))
159
+
160
+
161
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
173
+
174
+
175
class MiniGPT(nn.Module):
    """Minimal GPT: token + learned position embeddings, N blocks, linear LM head.

    ``forward`` returns ``(logits, loss)``; loss is None when no targets are given.
    """

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        # Note: the LM head is a plain dense nn.Linear, not SparseLinear.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        # Broadcast position embeddings over the batch dimension.
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
195
+
196
+
197
def get_sparse_linears(model):
    """Collect every SparseLinear module in ``model`` (module-walk order)."""
    found = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
199
+
200
+
201
class FastChunkScheduler:
    """Chooses which output chunks are active each training step.

    Chunk ids are global across all SparseLinear modules in the model; each
    module keeps a mapping from those global ids back to its local chunk
    indices. The scheduler maintains an EMA of per-chunk gradient mass and,
    for the "knn_scheduler" policy, a chunk-to-chunk similarity matrix built
    from warmup-time gradient-mass snapshots.
    """

    def __init__(self, model, scheduler: Scheduler, target_fraction: float, chunk_size: int, device: str,
                 mass_beta: float = 0.95, similarity_history: int = 128, min_similarity_history: int = 8,
                 knn_k: int = 3):
        self.scheduler = scheduler
        self.target_fraction = target_fraction        # final active fraction after annealing
        self.chunk_size = chunk_size
        self.device = device
        self.mass_beta = mass_beta                    # EMA decay for predicted gradient mass
        self.similarity_history = similarity_history  # max warmup snapshots retained
        self.min_similarity_history = min_similarity_history
        self.knn_k = knn_k                            # neighbours used for KNN imputation
        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids = {}   # module -> global chunk ids
        self.module_to_local_ids = {}   # module -> local chunk ids (0..n_chunks-1)
        offset = 0
        for m in self.linears:
            m.chunk_size = chunk_size
            # Chunking requires out_features to divide evenly into chunks.
            assert m.out_features % chunk_size == 0
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            self.module_to_local_ids[m] = torch.arange(n_chunks, device=device)
            offset += n_chunks
        self.n_chunks = offset
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)  # EMA of observed grad mass
        self.mass_history = []     # warmup-only snapshots feeding build_similarity
        self.similarity = None     # [n_chunks, n_chunks] correlation, or None until enough history
        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)
        self.sensor_scores = torch.zeros(self.n_chunks, device=device)

    def current_fraction(self, step: int, warmup_steps: int, anneal_steps: int) -> float:
        """Active fraction for this step: 1.0 in warmup, cosine-annealed down
        to ``target_fraction`` over ``anneal_steps``, then constant."""
        if self.scheduler == "dense" or step < warmup_steps:
            return 1.0
        if anneal_steps > 0 and step < warmup_steps + anneal_steps:
            progress = (step - warmup_steps) / anneal_steps
            cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
            return self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
        return self.target_fraction

    def choose_active(self, step: int, warmup_steps: int, anneal_steps: int):
        """Select the active chunk set for this step and push per-module masks."""
        frac = self.current_fraction(step, warmup_steps, anneal_steps)
        if frac >= 0.999 or self.scheduler == "dense":
            self.active_chunks.fill_(True)
            self.install_local_masks()
            return self.active_chunks
        k = max(1, int(frac * self.n_chunks))
        self.active_chunks.fill_(False)
        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        elif self.scheduler == "ema_topk":
            # Tiny noise breaks ties between equal scores.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices
        elif self.scheduler == "knn_scheduler":
            # Fall back to the EMA scores until sensor scores are populated.
            base = self.sensor_scores if torch.count_nonzero(self.sensor_scores).item() else self.predicted_mass
            scores = base + 1e-9 * torch.rand_like(base)
            idx = torch.topk(scores, k=k).indices
        else:
            raise ValueError(f"Unknown scheduler: {self.scheduler}")
        self.active_chunks[idx] = True
        self.install_local_masks()
        return self.active_chunks

    def install_local_masks(self):
        """Translate the global active mask into per-module local chunk id lists."""
        for m, global_ids in self.module_to_chunk_ids.items():
            local = self.module_to_local_ids[m]
            m.active_chunks = local[self.active_chunks[global_ids]]

    @torch.no_grad()
    def update_from_active_gradients(self, step: int, warmup_steps: int):
        """Refresh per-chunk mass statistics from the gradients just computed.

        Only active chunks carry real gradients this step; their EMA is
        updated (or initialized on first sight). During warmup the per-step
        mass snapshots also feed the similarity matrix.
        """
        current_mass = torch.zeros_like(self.predicted_mass)
        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None:
                continue
            # L2 mass of the weight (and bias) gradient per chunk.
            w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)
            current_mass[ids] = torch.sqrt(w_sq + 1e-30)
        observed = self.active_chunks
        # Chunks observed for the first time adopt the measurement directly;
        # previously seen chunks blend it into the EMA.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen
        self.predicted_mass[never_seen] = current_mass[never_seen]
        self.predicted_mass[already_seen] = self.mass_beta * self.predicted_mass[already_seen] + (1 - self.mass_beta) * current_mass[already_seen]
        if step < warmup_steps:
            self.mass_history.append(current_mass.detach().clone())
            if len(self.mass_history) > self.similarity_history:
                self.mass_history = self.mass_history[-self.similarity_history:]
            if len(self.mass_history) >= self.min_similarity_history:
                self.similarity = self.build_similarity()
        self.sensor_scores = self.knn_scores(self.active_chunks, current_mass) if self.scheduler == "knn_scheduler" else self.predicted_mass.clone()

    def build_similarity(self):
        """Correlation matrix of chunk mass over the warmup history.

        Negative correlations are clipped to zero, self-similarity is zeroed,
        and similarity is restricted to chunk pairs within the same module.
        """
        H = torch.stack(self.mass_history, dim=0)
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)
        S = (H.T @ H) / max(1, H.shape[0] - 1)
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)
        allowed = torch.zeros_like(S, dtype=torch.bool)
        for _, ids in self.module_to_chunk_ids.items():
            allowed[ids[:, None], ids[None, :]] = True
        return torch.where(allowed, S, torch.zeros_like(S))

    def knn_scores(self, active_mask, current_mass):
        """Score every chunk: active chunks by their fresh mass, inactive ones
        imputed from the k most similar active chunks (similarity-weighted
        mean of their current mass); falls back to the EMA otherwise."""
        if self.similarity is None:
            return self.predicted_mass.clone()
        scores = self.predicted_mass.clone()
        scores[active_mask] = current_mass[active_mask]
        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()
        if active_idx.numel() == 0:
            return scores
        S = self.similarity
        for i in inactive_idx.tolist():
            weights = S[i, active_idx]
            if weights.sum() <= 1e-12:
                continue  # no similar active chunk; keep the EMA score
            kk = min(self.knn_k, weights.numel())
            top = torch.topk(weights, k=kk)
            w = top.values
            aidx = active_idx[top.indices]
            scores[i] = (w * current_mass[aidx]).sum() / (w.sum() + 1e-12)
        return scores
324
+
325
+
326
class ChunkedAdam:
    """Adam-style optimizer that only updates active chunks of SparseLinear
    parameters; every other parameter receives a full dense update.

    Notes:
    - No bias-correction terms are applied (plain EMA moments).
    - Betas/eps are fixed at (0.9, 0.999) / 1e-8.
    - Chunking slices along dim 0 (out_features for weights, length for bias).
    """

    def __init__(self, model, lr=3e-4, chunk_size=64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        self.state = {}  # param -> {"m": first moment, "v": second moment}
        # Map each sparse parameter back to its module so step() can read
        # the module's currently installed active_chunks.
        self.param_to_sparse_module = {}
        for m in get_sparse_linears(model):
            if m.weight is not None:
                self.param_to_sparse_module[m.weight] = m
            if m.bias is not None:
                self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        # Drop gradients entirely rather than zeroing them in place.
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            exp_avg = self.state[p]["m"]
            exp_avg_sq = self.state[p]["v"]
            sparse_module = self.param_to_sparse_module.get(p)
            active_chunks = getattr(sparse_module, "active_chunks", None) if sparse_module else None
            if active_chunks is None:
                # Dense update: non-sparse params, or sparse module with no mask.
                exp_avg.mul_(0.9).add_(p.grad, alpha=0.1)
                exp_avg_sq.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                p.sub_(exp_avg / (torch.sqrt(exp_avg_sq) + 1e-8), alpha=self.lr)
            else:
                # Sparse update: only the rows of the active chunks are touched,
                # in parameters, moments, and the applied step alike.
                for local_c in active_chunks.tolist():
                    start = int(local_c) * self.chunk_size
                    end = start + self.chunk_size
                    p_chunk = p[start:end]
                    g_chunk = p.grad[start:end]
                    m_chunk = exp_avg[start:end]
                    v_chunk = exp_avg_sq[start:end]
                    m_chunk.mul_(0.9).add_(g_chunk, alpha=0.1)
                    v_chunk.mul_(0.999).addcmul_(g_chunk, g_chunk, value=0.001)
                    p_chunk.sub_(m_chunk / (torch.sqrt(v_chunk) + 1e-8), alpha=self.lr)
369
+
370
+
371
def evaluate(model, corpus, batch_size, seed):
    """Compute the validation loss on one deterministically sampled batch.

    Uses a fixed-seed CPU generator so every call (and every run) evaluates
    on the same batch. The model is restored to train mode even if the
    forward pass raises, which the original did not guarantee.
    """
    model.eval()
    try:
        with torch.no_grad():
            x, y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
            _, loss = model(x, y)
    finally:
        model.train()
    return float(loss.item())
378
+
379
+
380
def run_one(label, scheduler, mode, args):
    """Train one (scheduler, backward-mode) configuration from scratch.

    Returns ``{"val": validation loss, "ms": ms per measured step}``. The
    timer restarts at ``warmup_steps + anneal_steps`` so only the fully-sparse
    phase is measured (when that point falls inside the run).
    """
    set_seed(42)
    corpus = CharCorpus(make_synthetic_corpus(), args.block_size, args.device)
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.0).to(args.device)
    sched = FastChunkScheduler(model, scheduler, args.active_fraction, args.chunk_size, args.device)
    opt = ChunkedAdam(model, lr=args.lr, chunk_size=args.chunk_size)
    measured_steps = args.steps
    if args.benchmark_sync:
        sync_device(args.device)
    t0 = time.perf_counter()
    for step in range(args.steps):
        if step == args.warmup_steps + args.anneal_steps:
            # Restart the clock once annealing finishes.
            if args.benchmark_sync:
                sync_device(args.device)
            t0 = time.perf_counter()
            measured_steps = args.steps - step
        if scheduler == "dense" or mode == "dense_baseline":
            # Dense baseline: disable all sparse paths.
            for m in get_sparse_linears(model):
                m.sparse_enabled = False
                m.active_chunks = None
        else:
            sched.choose_active(step, args.warmup_steps, args.anneal_steps)
            for m in get_sparse_linears(model):
                m.sparse_enabled = True
                m.sparse_dx = mode == "sparse_dW_sparse_dX"
        # Per-step seeded generator: every configuration sees identical data.
        x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))
        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()
        if scheduler != "dense" and mode != "dense_baseline":
            sched.update_from_active_gradients(step, args.warmup_steps)
        opt.step()
    if args.benchmark_sync:
        sync_device(args.device)
    elapsed = time.perf_counter() - t0
    return {"val": evaluate(model, corpus, args.batch_size, 12345), "ms": 1000 * elapsed / max(1, measured_steps)}
416
+
417
+
418
def main():
    """Run all scheduler/backward-mode configurations and print a comparison table."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=16)
    parser.add_argument("--n_embd", type=int, default=1024)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--anneal_steps", type=int, default=150)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()
    # (label, scheduler policy, backward mode) triples; dense runs first so
    # its ms/step can serve as the speedup baseline for the rest.
    runs = [
        ("dense", "dense", "dense_baseline"),
        ("ema_full_dX", "ema_topk", "sparse_dW_full_dX"),
        ("knn_full_dX", "knn_scheduler", "sparse_dW_full_dX"),
        ("random_full_dX", "random", "sparse_dW_full_dX"),
        ("ema_sparse_dX", "ema_topk", "sparse_dW_sparse_dX"),
        ("knn_sparse_dX", "knn_scheduler", "sparse_dW_sparse_dX"),
        ("random_sparse_dX", "random", "sparse_dW_sparse_dX"),
    ]
    print("\nFast chunked sparse backward with KNN scheduler")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} layers={args.n_layer}")
    print(f"batch={args.batch_size} block={args.block_size} chunk={args.chunk_size}")
    print(f"active={args.active_fraction} warmup={args.warmup_steps} anneal={args.anneal_steps}\n")
    print(f"{'run':>18s} | {'val':>8s} | {'ms/step':>8s} | {'speedup':>8s}")
    print("-" * 58)
    dense_ms = None
    for label, scheduler, mode in runs:
        result = run_one(label, scheduler, mode, args)
        if label == "dense":
            dense_ms = result["ms"]
        speedup = dense_ms / result["ms"] if dense_ms else float("nan")
        print(f"{label:>18s} | {result['val']:8.4f} | {result['ms']:8.2f} | {speedup:8.3f}")


if __name__ == "__main__":
    main()
sparse_transformer_v18_fast_knn_triton.py ADDED
@@ -0,0 +1,1044 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Sparse Transformer v19: Triton-backed v18 KNN Scheduler.
4
+
5
+ This is the CUDA/Triton version of the fast chunked sparse-backward loop.
6
+
7
+ It combines:
8
+ - chunked sparse Linear backward
9
+ - Triton fused active-chunk dW + dBias
10
+ - optional Triton sparse dX
11
+ - EMA / KNN sensor scheduler / random support
12
+ - chunked sparse Adam update
13
+
14
+ Core modes:
15
+ dense
16
+ ema_full_dX
17
+ knn_full_dX
18
+ random_full_dX
19
+ ema_sparse_dX
20
+ knn_sparse_dX
21
+ random_sparse_dX
22
+
23
+ Safe mode:
24
+ knn_full_dX
25
+ Forward dense, dW/db sparse, dX full.
26
+
27
+ Aggressive mode:
28
+ knn_sparse_dX
29
+ Forward dense, dW/db sparse, dX sparse through active chunks.
30
+
31
+ Run:
32
+ python3 sparse_transformer_v18_fast_knn_triton.py --device cuda --benchmark_sync
33
+
34
+ Useful:
35
+ python3 sparse_transformer_v18_fast_knn_triton.py --device cuda --steps 500 --n_embd 1024 --benchmark_sync
36
+ python3 sparse_transformer_v18_fast_knn_triton.py --device cuda --steps 500 --n_embd 2048 --benchmark_sync
37
+
38
+ Notes:
39
+ - This script needs CUDA + Triton.
40
+ - No autotune. Fixed configs reduce compile noise and keep comparisons stable.
41
+ - dW+dBias is fused.
42
+ - Uses block_ptr/tiled loads. On T4 this is not Hopper TMA; do not call it TMA.
43
+ """
44
+
45
+ from __future__ import annotations
46
+
47
+ import argparse
48
+ import math
49
+ import random
50
+ import time
51
+ from typing import Dict, List, Literal, Optional, Tuple
52
+
53
+ import torch
54
+
55
+ torch.set_num_threads(1)
56
+ import torch.nn as nn
57
+ import torch.nn.functional as F
58
+
59
+ try:
60
+ import triton
61
+ import triton.language as tl
62
+ TRITON_AVAILABLE = True
63
+ except Exception:
64
+ triton = None
65
+ tl = None
66
+ TRITON_AVAILABLE = False
67
+
68
+
69
+ Scheduler = Literal["dense", "ema_topk", "knn_scheduler", "random"]
70
+ BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
71
+ KernelBackend = Literal["triton", "torch"]
72
+
73
+
74
+ # ================================================================
75
+ # Utilities
76
+ # ================================================================
77
+
78
def sync_device(device: str) -> None:
    """Block until all queued kernels on ``device`` have finished.

    No-op for CPU or for accelerators that are unavailable at runtime.
    """
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
83
+
84
+
85
def set_seed(seed: int) -> None:
    """Seed python and torch RNGs, including all CUDA devices when present."""
    random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
90
+
91
+
92
def make_cpu_generator(seed: int) -> torch.Generator:
    """Build a CPU-side ``torch.Generator`` seeded with ``seed``.

    ``Generator.manual_seed`` returns the generator itself, enabling the chain.
    """
    return torch.Generator(device="cpu").manual_seed(seed)
96
+
97
+
98
+ # ================================================================
99
+ # Data
100
+ # ================================================================
101
+
102
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Generate a deterministic corpus of newline-separated random sentences.

    Each sentence is 4-10 words drawn (with replacement) from a fixed
    vocabulary and terminated with a period.
    """
    rng = random.Random(seed)
    vocab = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sentences = []
    for _ in range(n_sentences):
        # randint first (it was the argument expression), preserving RNG order.
        n_words = rng.randint(4, 10)
        sentences.append(" ".join(rng.choices(vocab, k=n_words)) + ".")
    return "\n".join(sentences)
113
+
114
+
115
class CharCorpus:
    """Character-level corpus with a fixed 90/10 train/validation split.

    Builds char<->id vocab tables from the text and keeps the encoded data
    split into ``train_data`` / ``val_data``.
    """

    def __init__(self, text: str, block_size: int, device: str):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.device = device

        data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        self.train_data = data[: int(0.9 * len(data))]
        self.val_data = data[int(0.9 * len(data)) :]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` windows of (input, next-char target) pairs.

        Targets are the inputs shifted by one position; sampling on CPU with
        an optional generator keeps batches reproducible across runs.
        """
        data = self.train_data if split == "train" else self.val_data
        ix = torch.randint(len(data) - self.block_size - 1, (batch_size,), generator=generator)
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)
139
+
140
+
141
+ # ================================================================
142
+ # Triton sparse Linear backward kernels
143
+ # ================================================================
144
+
145
if TRITON_AVAILABLE:
    @triton.jit
    def _triton_sparse_bwd_dW_db_kernel(
        X_ptr, dY_ptr, dW_ptr, dB_ptr, chunk_ids_ptr,
        M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr, num_active: tl.constexpr,
        stride_xm: tl.constexpr, stride_xk: tl.constexpr,
        stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
        stride_dwn: tl.constexpr, stride_dwk: tl.constexpr,
        HAS_BIAS: tl.constexpr,
        CS: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr,
    ):
        """
        One program computes one [CS, BK] dW tile for one active chunk.
        Bias for that chunk is fused into k_block_id == 0.
        Grid: (num_active, ceil(d_in / BK))
        """
        chunk_linear_id = tl.program_id(0)
        k_block_id = tl.program_id(1)

        # Translate the dense program index into the sparse chunk's row offset.
        chunk_idx = tl.load(chunk_ids_ptr + chunk_linear_id)
        chunk_start = chunk_idx * CS
        k_offset = k_block_id * BK

        # dY is read transposed: strides are swapped so the tile is [CS, BM].
        dy_block_ptr = tl.make_block_ptr(
            base=dY_ptr,
            shape=(d_out, M),
            strides=(stride_dyn, stride_dym),
            offsets=(chunk_start, 0),
            block_shape=(CS, BM),
            order=(1, 0),
        )

        x_block_ptr = tl.make_block_ptr(
            base=X_ptr,
            shape=(M, d_in),
            strides=(stride_xm, stride_xk),
            offsets=(0, k_offset),
            block_shape=(BM, BK),
            order=(1, 0),
        )

        acc_dw = tl.zeros((CS, BK), dtype=tl.float32)
        # Only the first k-block of each chunk accumulates the bias gradient,
        # so dB is written exactly once per chunk.
        compute_bias = HAS_BIAS and (k_block_id == 0)
        acc_db = tl.zeros((CS,), dtype=tl.float32)

        # March over the M (token) dimension in BM-sized tiles.
        for _ in range(0, M, BM):
            dy_t = tl.load(dy_block_ptr, boundary_check=(0, 1))  # [CS, BM]
            x = tl.load(x_block_ptr, boundary_check=(0, 1))      # [BM, BK]

            acc_dw = tl.dot(dy_t, x, acc=acc_dw)

            if compute_bias:
                acc_db += tl.sum(dy_t, axis=1)

            dy_block_ptr = tl.advance(dy_block_ptr, (0, BM))
            x_block_ptr = tl.advance(x_block_ptr, (BM, 0))

        dw_block_ptr = tl.make_block_ptr(
            base=dW_ptr,
            shape=(d_out, d_in),
            strides=(stride_dwn, stride_dwk),
            offsets=(chunk_start, k_offset),
            block_shape=(CS, BK),
            order=(1, 0),
        )
        tl.store(dw_block_ptr, acc_dw.to(dW_ptr.dtype.element_ty), boundary_check=(0, 1))

        if compute_bias:
            rn = chunk_start + tl.arange(0, CS)
            tl.store(dB_ptr + rn, acc_db.to(dB_ptr.dtype.element_ty), mask=rn < d_out)


    @triton.jit
    def _triton_sparse_bwd_dX_kernel(
        dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
        M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr, num_active: tl.constexpr,
        stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
        stride_wn: tl.constexpr, stride_wk: tl.constexpr,
        stride_dxm: tl.constexpr, stride_dxk: tl.constexpr,
        CS: tl.constexpr, BM: tl.constexpr, BK: tl.constexpr,
    ):
        """
        One program computes one [BM, BK] tile of dX by accumulating over active chunks.
        Grid: (ceil(M/BM), ceil(d_in/BK)).
        """
        pid_m = tl.program_id(0)
        pid_k = tl.program_id(1)

        m_offset = pid_m * BM
        k_offset = pid_k * BK

        acc = tl.zeros((BM, BK), dtype=tl.float32)

        # dX[m, k] = sum over active chunks c of dY[m, c] @ W[c, k];
        # inactive chunks are simply skipped.
        for i in range(0, num_active):
            chunk_idx = tl.load(chunk_ids_ptr + i)
            chunk_start = chunk_idx * CS

            dy_block_ptr = tl.make_block_ptr(
                base=dY_ptr,
                shape=(M, d_out),
                strides=(stride_dym, stride_dyn),
                offsets=(m_offset, chunk_start),
                block_shape=(BM, CS),
                order=(1, 0),
            )

            w_block_ptr = tl.make_block_ptr(
                base=W_ptr,
                shape=(d_out, d_in),
                strides=(stride_wn, stride_wk),
                offsets=(chunk_start, k_offset),
                block_shape=(CS, BK),
                order=(1, 0),
            )

            dy = tl.load(dy_block_ptr, boundary_check=(0, 1))  # [BM, CS]
            w = tl.load(w_block_ptr, boundary_check=(0, 1))    # [CS, BK]
            acc = tl.dot(dy, w, acc=acc)

        dx_block_ptr = tl.make_block_ptr(
            base=dX_ptr,
            shape=(M, d_in),
            strides=(stride_dxm, stride_dxk),
            offsets=(m_offset, k_offset),
            block_shape=(BM, BK),
            order=(1, 0),
        )
        tl.store(dx_block_ptr, acc.to(dX_ptr.dtype.element_ty), boundary_check=(0, 1))
273
+
274
+
275
def triton_sparse_bwd_dW_db(
    x_flat: torch.Tensor,
    gy_flat: torch.Tensor,
    active_chunks: torch.Tensor,
    chunk_size: int,
    d_out: int,
    has_bias: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Launch the fused dW + dBias kernel over the active output chunks.

    Returns ``(dW, dB)``; rows of inactive chunks stay zero and ``dB`` is
    None when ``has_bias`` is False.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not available")

    M, d_in = x_flat.shape
    num_active = int(active_chunks.numel())

    dW = torch.zeros((d_out, d_in), device=x_flat.device, dtype=x_flat.dtype)
    dB = torch.zeros((d_out,), device=x_flat.device, dtype=x_flat.dtype) if has_bias else None

    if num_active == 0:
        # Nothing active: all-zero gradients, no launch.
        return dW, dB

    chunk_ids = active_chunks.to(torch.int32).contiguous()

    # Fixed configs.
    CS = int(chunk_size)
    BK = 64
    BM = 64

    grid = (num_active, triton.cdiv(d_in, BK))

    # When there is no bias, dW is passed as a placeholder pointer for dB;
    # HAS_BIAS=False guards every store through it inside the kernel.
    _triton_sparse_bwd_dW_db_kernel[grid](
        x_flat, gy_flat, dW, dB if has_bias else dW, chunk_ids,
        M, d_in, d_out, num_active,
        x_flat.stride(0), x_flat.stride(1),
        gy_flat.stride(0), gy_flat.stride(1),
        dW.stride(0), dW.stride(1),
        HAS_BIAS=has_bias,
        CS=CS, BK=BK, BM=BM,
        num_warps=4,
    )

    return dW, dB
316
+
317
+
318
def triton_sparse_bwd_dX(
    gy_flat: torch.Tensor,
    weight: torch.Tensor,
    active_chunks: torch.Tensor,
    chunk_size: int,
    M: int,
    d_in: int,
) -> torch.Tensor:
    """Compute the chunk-sparse input gradient dX = sum_c gy[:, c] @ W[c].

    Only the active output chunks contribute to dX (an approximation of the
    dense dX = gy @ W restricted to the selected chunks).

    Args:
        gy_flat: [M, d_out] flattened upstream gradient.
        weight: [d_out, d_in] weight matrix (must be contiguous for strides).
        active_chunks: 1-D tensor of active chunk indices into the output dim.
        chunk_size: rows of `weight` per chunk.
        M: number of flattened rows of the input.
        d_in: input feature dimension.

    Returns:
        dX of shape [M, d_in]; all-zero when no chunks are active.

    Raises:
        RuntimeError: if Triton is not importable on this system.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not available")

    num_active = int(active_chunks.numel())
    d_out = gy_flat.shape[1]
    dX = torch.zeros((M, d_in), device=gy_flat.device, dtype=gy_flat.dtype)

    if num_active == 0:
        return dX

    chunk_ids = active_chunks.to(torch.int32).contiguous()

    CS = int(chunk_size)
    BM = 64
    BK = 64

    # One program per (row tile, d_in tile); each loops over active chunks.
    grid = (triton.cdiv(M, BM), triton.cdiv(d_in, BK))

    _triton_sparse_bwd_dX_kernel[grid](
        gy_flat, weight, dX, chunk_ids,
        M, d_in, d_out, num_active,
        gy_flat.stride(0), gy_flat.stride(1),
        weight.stride(0), weight.stride(1),
        dX.stride(0), dX.stride(1),
        CS=CS, BM=BM, BK=BK,
        num_warps=4,
    )

    return dX
355
+
356
+
357
+ # ================================================================
358
+ # Sparse Linear autograd
359
+ # ================================================================
360
+
361
class ChunkedMaskedLinearTorch(torch.autograd.Function):
    """Dense forward, chunk-sparse backward, implemented with plain PyTorch ops.

    Weight/bias gradients are computed only for the active output chunks;
    dX is either dense (default) or restricted to active chunks (sparse_dx).
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_chunks: torch.Tensor,
        chunk_size: int,
        sparse_dx: bool,
    ) -> torch.Tensor:
        # Forward is an ordinary dense linear; sparsity only shapes backward.
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        ctx.chunk_size = int(chunk_size)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        cs = ctx.chunk_size

        xf = x.reshape(-1, x.shape[-1])
        gyf = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight)
        grad_b = (
            torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)
            if ctx.has_bias
            else None
        )

        # Dense dX unless the caller requested the chunk-sparse approximation.
        grad_x_flat = torch.zeros_like(xf) if ctx.sparse_dx else gyf @ weight

        for chunk in active_chunks.tolist():
            lo = int(chunk) * cs
            hi = lo + cs
            gy_c = gyf[:, lo:hi]

            grad_w[lo:hi, :] = gy_c.transpose(0, 1) @ xf
            if grad_b is not None:
                grad_b[lo:hi] = gy_c.sum(dim=0)
            if ctx.sparse_dx:
                grad_x_flat += gy_c @ weight[lo:hi, :]

        return grad_x_flat.reshape(x.shape), grad_w, grad_b, None, None, None
410
+
411
+
412
class ChunkedMaskedLinearTriton(torch.autograd.Function):
    """Dense forward; backward delegates to the Triton chunk-sparse kernels.

    Mirrors ChunkedMaskedLinearTorch's contract but computes dW/dB (and
    optionally the sparse dX) via triton_sparse_bwd_dW_db / triton_sparse_bwd_dX.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_chunks: torch.Tensor,
        chunk_size: int,
        sparse_dx: bool,
    ) -> torch.Tensor:
        # Forward stays dense; only the backward pass is chunk-sparse.
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        ctx.chunk_size = int(chunk_size)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        chunk_size = ctx.chunk_size

        x_shape = x.shape
        # Triton kernels index via explicit strides, so flatten and make
        # the inputs contiguous before launching.
        x_flat = x.reshape(-1, x.shape[-1]).contiguous()
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1]).contiguous()

        grad_w, grad_b = triton_sparse_bwd_dW_db(
            x_flat=x_flat,
            gy_flat=gy_flat,
            active_chunks=active_chunks,
            chunk_size=chunk_size,
            d_out=weight.shape[0],
            has_bias=ctx.has_bias,
        )

        if ctx.sparse_dx:
            grad_x_flat = triton_sparse_bwd_dX(
                gy_flat=gy_flat,
                weight=weight.contiguous(),
                active_chunks=active_chunks,
                chunk_size=chunk_size,
                M=x_flat.shape[0],
                d_in=x_flat.shape[1],
            )
        else:
            # Full dX is a plain dense matmul; no kernel needed.
            grad_x_flat = gy_flat @ weight

        grad_x = grad_x_flat.reshape(x_shape)
        return grad_x, grad_w, grad_b, None, None, None
461
+
462
+
463
class SparseLinear(nn.Linear):
    """nn.Linear whose backward pass can be restricted to active output chunks.

    The scheduler toggles `sparse_enabled` / `active_chunks`; when disabled
    (or no chunks are installed) this behaves exactly like nn.Linear.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False                    # scheduler switch
        self.sparse_dx = False                         # sparse dX approximation
        self.active_chunks: Optional[torch.Tensor] = None  # local chunk ids
        self.chunk_size = 64
        self.kernel_backend: KernelBackend = "triton"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.sparse_enabled or self.active_chunks is None:
            return F.linear(x, self.weight, self.bias)

        fn = (
            ChunkedMaskedLinearTriton
            if self.kernel_backend == "triton"
            else ChunkedMaskedLinearTorch
        )
        return fn.apply(
            x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx
        )
484
+
485
+
486
+ # ================================================================
487
+ # MiniGPT
488
+ # ================================================================
489
+
490
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with SparseLinear projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular buffer used to mask out future positions.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)

        # [B, T, C] -> [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        probs = self.dropout(F.softmax(scores, dim=-1))

        out = (probs @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
521
+
522
+
523
class FeedForward(nn.Module):
    """Standard GPT MLP: expand 4x, GELU, project back, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
532
+
533
+
534
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
546
+
547
+
548
class MiniGPT(nn.Module):
    """Small GPT-style decoder-only language model over character tokens."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)

        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
580
+
581
+
582
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every SparseLinear submodule of *model* in traversal order."""
    found: List[SparseLinear] = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
584
+
585
+
586
+ # ================================================================
587
+ # Scheduler
588
+ # ================================================================
589
+
590
class FastChunkScheduler:
    """Chooses which output chunks of every SparseLinear are active each step.

    Chunks get globally unique ids: each SparseLinear owns a contiguous range
    of `out_features // chunk_size` ids. The scheduler tracks an EMA of
    per-chunk gradient mass ("predicted_mass"), and for the KNN scheduler a
    chunk-by-chunk similarity matrix built from warmup gradient-mass history.
    """

    def __init__(
        self,
        model: nn.Module,
        scheduler: Scheduler,
        target_fraction: float,
        chunk_size: int,
        device: str,
        mass_beta: float = 0.95,
        similarity_history: int = 128,
        min_similarity_history: int = 8,
        knn_k: int = 3,
    ):
        self.model = model
        self.scheduler = scheduler
        self.target_fraction = target_fraction  # steady-state active fraction
        self.chunk_size = chunk_size
        self.device = device
        self.mass_beta = mass_beta  # EMA decay for predicted gradient mass
        self.similarity_history = similarity_history  # max warmup snapshots kept
        self.min_similarity_history = min_similarity_history  # needed before KNN works
        self.knn_k = knn_k  # neighbours used to impute inactive-chunk scores

        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.module_to_local_ids: Dict[nn.Module, torch.Tensor] = {}

        # Assign each module a contiguous range of global chunk ids.
        offset = 0
        for m in self.linears:
            m.chunk_size = chunk_size
            n_chunks = m.out_features // chunk_size
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )

            ids = torch.arange(offset, offset + n_chunks, device=device)
            local = torch.arange(n_chunks, device=device)

            self.module_to_chunk_ids[m] = ids
            self.module_to_local_ids[m] = local
            offset += n_chunks

        self.n_chunks = offset
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.mass_history: List[torch.Tensor] = []  # warmup-only snapshots
        self.similarity: Optional[torch.Tensor] = None  # [n_chunks, n_chunks]

        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)
        self.sensor_scores = torch.zeros(self.n_chunks, device=device)

    def current_fraction(self, step: int, warmup_steps: int, anneal_steps: int) -> float:
        """Active fraction for this step: 1.0 during warmup, cosine-annealed
        down to target_fraction over `anneal_steps`, then constant."""
        if self.scheduler == "dense":
            return 1.0
        if step < warmup_steps:
            return 1.0
        if anneal_steps > 0 and step < warmup_steps + anneal_steps:
            progress = (step - warmup_steps) / anneal_steps
            cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
            return self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
        return self.target_fraction

    def choose_active(self, step: int, warmup_steps: int, anneal_steps: int) -> torch.Tensor:
        """Select the active chunk set for this step and install it on the
        SparseLinear modules. Returns the global boolean active mask."""
        frac = self.current_fraction(step, warmup_steps, anneal_steps)

        # At (near) full fraction everything is active regardless of policy.
        if frac >= 0.999 or self.scheduler == "dense":
            self.active_chunks.fill_(True)
            self.install_local_masks()
            return self.active_chunks

        k = max(1, int(frac * self.n_chunks))
        self.active_chunks.fill_(False)

        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]

        elif self.scheduler == "ema_topk":
            # Tiny random jitter breaks ties between equal-mass chunks.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices

        elif self.scheduler == "knn_scheduler":
            base = self.sensor_scores
            # Fall back to the EMA before any KNN scores exist.
            if torch.count_nonzero(base).item() == 0:
                base = self.predicted_mass
            scores = base + 1e-9 * torch.rand_like(base)
            idx = torch.topk(scores, k=k).indices

        else:
            raise ValueError(f"Unknown scheduler: {self.scheduler}")

        self.active_chunks[idx] = True
        self.install_local_masks()
        return self.active_chunks

    def install_local_masks(self) -> None:
        """Translate the global active mask into per-module local chunk ids."""
        for m, global_ids in self.module_to_chunk_ids.items():
            local = self.module_to_local_ids[m]
            m.active_chunks = local[self.active_chunks[global_ids]]

    @torch.no_grad()
    def update_from_active_gradients(self, step: int, warmup_steps: int) -> torch.Tensor:
        """After backward(): refresh per-chunk gradient-mass statistics.

        Computes the L2 gradient mass per chunk from the (chunk-sparse)
        grads, EMA-updates predicted_mass for observed chunks, records
        warmup history for the similarity matrix, and refreshes the
        scores the next choose_active() will rank by.
        """
        current_mass = torch.zeros_like(self.predicted_mass)

        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None:
                continue

            # Per-chunk squared-grad sum over the chunk's weight rows (+ bias).
            w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)

            current_mass[ids] = torch.sqrt(w_sq + 1e-30)

        observed = self.active_chunks
        # First observation seeds the EMA directly instead of decaying a zero.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen

        self.predicted_mass[never_seen] = current_mass[never_seen]
        self.predicted_mass[already_seen] = (
            self.mass_beta * self.predicted_mass[already_seen]
            + (1.0 - self.mass_beta) * current_mass[already_seen]
        )

        # Similarity is learned only during warmup, when all chunks are
        # active so every snapshot is a complete observation.
        if step < warmup_steps:
            self.mass_history.append(current_mass.detach().clone())
            if len(self.mass_history) > self.similarity_history:
                self.mass_history = self.mass_history[-self.similarity_history :]
            if len(self.mass_history) >= self.min_similarity_history:
                self.similarity = self.build_similarity()

        if self.scheduler == "knn_scheduler":
            self.sensor_scores = self.knn_scores(self.active_chunks, current_mass)
        else:
            self.sensor_scores = self.predicted_mass.clone()

        return current_mass

    def build_similarity(self) -> torch.Tensor:
        """Correlation-style similarity between chunks over warmup history.

        Returns a [n_chunks, n_chunks] matrix, clamped to non-negative,
        zero diagonal, and restricted to chunk pairs within the same module.
        """
        H = torch.stack(self.mass_history, dim=0)
        # Standardize each chunk's mass series before correlating.
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)

        S = (H.T @ H) / max(1, H.shape[0] - 1)
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)

        # Only allow similarity edges between chunks of the same module.
        allowed = torch.zeros_like(S, dtype=torch.bool)
        for _, ids in self.module_to_chunk_ids.items():
            allowed[ids[:, None], ids[None, :]] = True

        return torch.where(allowed, S, torch.zeros_like(S))

    def knn_scores(self, active_mask: torch.Tensor, current_mass: torch.Tensor) -> torch.Tensor:
        """Score every chunk: active chunks get their fresh mass, inactive
        chunks get a similarity-weighted average of their top-k most
        similar active neighbours (falling back to the EMA otherwise)."""
        if self.similarity is None:
            return self.predicted_mass.clone()

        scores = self.predicted_mass.clone()
        scores[active_mask] = current_mass[active_mask]

        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()

        if active_idx.numel() == 0:
            return scores

        S = self.similarity
        for i in inactive_idx.tolist():
            weights = S[i, active_idx]
            if weights.sum() <= 1e-12:
                continue  # no similar active neighbour; keep EMA score

            kk = min(self.knn_k, weights.numel())
            top = torch.topk(weights, k=kk)
            w = top.values
            aidx = active_idx[top.indices]
            scores[i] = (w * current_mass[aidx]).sum() / (w.sum() + 1e-12)

        return scores
767
+
768
+
769
+ # ================================================================
770
+ # Chunked Adam
771
+ # ================================================================
772
+
773
class ChunkedAdam:
    """Adam-style optimizer (fixed betas, no bias correction) that updates
    only the active chunks of SparseLinear parameters.

    Parameters without an installed chunk mask get a full dense update.
    """

    def __init__(self, model: nn.Module, lr: float = 3e-4, chunk_size: int = 64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

        # Map each SparseLinear weight/bias back to its owning module so
        # step() can look up the module's current active chunk ids.
        self.param_to_sparse_module: Dict[torch.nn.Parameter, SparseLinear] = {}
        for module in get_sparse_linears(model):
            if module.weight is not None:
                self.param_to_sparse_module[module.weight] = module
            if module.bias is not None:
                self.param_to_sparse_module[module.bias] = module

    def zero_grad(self):
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        def adam_update(pv, gv, mv, vv):
            # In-place Adam update on (possibly sliced) views of the state.
            mv.mul_(0.9).add_(gv, alpha=0.1)
            vv.mul_(0.999).addcmul_(gv, gv, value=0.001)
            pv.sub_(mv / (torch.sqrt(vv) + 1e-8), alpha=self.lr)

        for p in self.model.parameters():
            grad = p.grad
            if grad is None:
                continue

            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            st = self.state[p]

            owner = self.param_to_sparse_module.get(p)
            chunks = getattr(owner, "active_chunks", None) if owner else None

            if chunks is None:
                adam_update(p, grad, st["m"], st["v"])
                continue

            # Only the active chunks' slices of parameter and state move.
            for c in chunks.tolist():
                lo = int(c) * self.chunk_size
                hi = lo + self.chunk_size
                adam_update(p[lo:hi], grad[lo:hi], st["m"][lo:hi], st["v"][lo:hi])
824
+
825
+ # ================================================================
826
+ # Training
827
+ # ================================================================
828
+
829
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Compute the validation loss on one deterministic batch.

    Args:
        model: the MiniGPT (or compatible) model; returns (logits, loss).
        corpus: data source providing get_batch("val", ...).
        batch_size: number of sequences in the eval batch.
        seed: seed for the CPU generator, so eval is reproducible across runs.

    Returns:
        The scalar validation loss as a Python float.
    """
    model.eval()
    try:
        with torch.no_grad():
            x, y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
            _, loss = model(x, y)
        return float(loss.item())
    finally:
        # Restore training mode even if the forward pass raises, so a failed
        # eval can't silently leave dropout/etc. disabled for later training.
        model.train()
836
+
837
+
838
def run_one(
    scheduler_name: Scheduler,
    mode: BackwardMode,
    kernel_backend: KernelBackend,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    anneal_steps: int,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train one MiniGPT configuration and report val loss + ms/step.

    Timing is restarted once the schedule reaches steady state
    (warmup + anneal), so the reported ms/step measures only the
    target-fraction regime. With benchmark_sync, the device is
    synchronized around the timed region for accurate wall clock.
    """
    set_seed(42)

    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)

    for m in get_sparse_linears(model):
        m.chunk_size = chunk_size
        m.kernel_backend = kernel_backend

    sched = FastChunkScheduler(
        model=model,
        scheduler=scheduler_name,
        target_fraction=active_fraction,
        chunk_size=chunk_size,
        device=device,
    )
    opt = ChunkedAdam(model, lr=3e-4, chunk_size=chunk_size)

    measured_steps = steps

    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()

    for step in range(steps):
        # Restart the timer at steady state so warmup/anneal don't pollute
        # the ms/step measurement.
        if step == warmup_steps + anneal_steps:
            if benchmark_sync:
                sync_device(device)
            t0 = time.perf_counter()
            measured_steps = steps - step

        if scheduler_name == "dense" or mode == "dense_baseline":
            # Dense run: disable sparsity on every layer.
            for m in get_sparse_linears(model):
                m.sparse_enabled = False
                m.active_chunks = None
        else:
            sched.choose_active(step, warmup_steps=warmup_steps, anneal_steps=anneal_steps)
            for m in get_sparse_linears(model):
                m.sparse_enabled = True
                m.sparse_dx = mode == "sparse_dW_sparse_dX"

        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))

        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        # Feed the fresh chunk gradients back into the scheduler stats.
        if scheduler_name != "dense" and mode != "dense_baseline":
            sched.update_from_active_gradients(step=step, warmup_steps=warmup_steps)

        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0

    val_loss = evaluate(model, corpus, batch_size, seed=12345)
    return {"val": val_loss, "ms": 1000.0 * elapsed / max(1, measured_steps)}
913
+
914
+
915
def run_correctness_smoke(
    device: str,
    chunk_size: int = 64,
    dtype: torch.dtype = torch.float32,
) -> None:
    """Check the Triton sparse backward kernels against a dense reference.

    Prints per-shape max-abs errors for dW, dB and dX; skips (with a
    message) when CUDA or Triton is unavailable.
    """
    if device != "cuda":
        print("Skipping Triton correctness smoke: device is not cuda")
        return
    if not TRITON_AVAILABLE:
        print("Skipping Triton correctness smoke: Triton not available")
        return

    print("\nTriton sparse Linear correctness smoke")
    print("-" * 60)

    torch.manual_seed(123)

    for d_in, d_out in ((512, 2048), (1024, 4096)):
        M = 512
        n_chunks = d_out // chunk_size
        n_active = max(1, int(0.1 * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values

        x = torch.randn(M, d_in, device=device, dtype=dtype)
        w = torch.randn(d_out, d_in, device=device, dtype=dtype)
        gy = torch.randn(M, d_out, device=device, dtype=dtype)

        # Dense per-chunk reference computed with plain matmuls.
        ref_dw = torch.zeros_like(w)
        ref_db = torch.zeros(d_out, device=device, dtype=dtype)
        ref_dx = torch.zeros_like(x)
        for c in active.tolist():
            lo = c * chunk_size
            hi = lo + chunk_size
            ref_dw[lo:hi] = gy[:, lo:hi].transpose(0, 1) @ x
            ref_db[lo:hi] = gy[:, lo:hi].sum(0)
            ref_dx += gy[:, lo:hi] @ w[lo:hi]

        tri_dw, tri_db = triton_sparse_bwd_dW_db(x, gy, active, chunk_size, d_out, True)
        tri_dx = triton_sparse_bwd_dX(gy, w, active, chunk_size, M, d_in)

        dw_err = float((tri_dw - ref_dw).abs().max().item())
        db_err = float((tri_db - ref_db).abs().max().item())
        dx_err = float((tri_dx - ref_dx).abs().max().item())

        print(f"d_in={d_in:4d} d_out={d_out:5d}: dW={dw_err:.6f} dB={db_err:.6f} dX={dx_err:.6f}")
962
+
963
+
964
def main() -> None:
    """CLI entry point: optional kernel smoke test, then a benchmark sweep
    over scheduler/backward-mode combinations, printing val loss, ms/step
    and speedup relative to the dense baseline."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=16)
    parser.add_argument("--n_embd", type=int, default=1024)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--anneal_steps", type=int, default=150)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--kernel_backend", type=str, default="triton", choices=["triton", "torch"])
    parser.add_argument("--benchmark_sync", action="store_true")
    parser.add_argument("--skip_correctness", action="store_true")
    args = parser.parse_args()

    # Triton kernels require a CUDA device and an importable triton package.
    if args.kernel_backend == "triton":
        if args.device != "cuda":
            raise RuntimeError("--kernel_backend triton requires --device cuda")
        if not TRITON_AVAILABLE:
            raise RuntimeError("Triton is not available")

    if not args.skip_correctness:
        run_correctness_smoke(device=args.device, chunk_size=args.chunk_size)

    # (label, scheduler policy, backward mode) combinations to benchmark.
    runs: List[Tuple[str, Scheduler, BackwardMode]] = [
        ("dense", "dense", "dense_baseline"),
        ("ema_full_dX", "ema_topk", "sparse_dW_full_dX"),
        ("knn_full_dX", "knn_scheduler", "sparse_dW_full_dX"),
        ("random_full_dX", "random", "sparse_dW_full_dX"),
        ("ema_sparse_dX", "ema_topk", "sparse_dW_sparse_dX"),
        ("knn_sparse_dX", "knn_scheduler", "sparse_dW_sparse_dX"),
        ("random_sparse_dX", "random", "sparse_dW_sparse_dX"),
    ]

    print("\nTriton-backed fast chunked sparse backward with KNN scheduler")
    print(f"device={args.device} backend={args.kernel_backend} triton_available={TRITON_AVAILABLE}")
    print(f"steps={args.steps} d={args.n_embd} layers={args.n_layer}")
    print(f"batch={args.batch_size} block={args.block_size} chunk={args.chunk_size}")
    print(f"active={args.active_fraction} warmup={args.warmup_steps} anneal={args.anneal_steps}\n")
    print(f"{'run':>18s} | {'val':>8s} | {'ms/step':>8s} | {'speedup':>8s}")
    print("-" * 58)

    dense_ms: Optional[float] = None

    for label, scheduler, mode in runs:
        result = run_one(
            scheduler_name=scheduler,
            mode=mode,
            kernel_backend=args.kernel_backend,  # type: ignore[arg-type]
            device=args.device,
            steps=args.steps,
            batch_size=args.batch_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            chunk_size=args.chunk_size,
            active_fraction=args.active_fraction,
            warmup_steps=args.warmup_steps,
            anneal_steps=args.anneal_steps,
            benchmark_sync=args.benchmark_sync,
        )

        # The dense run comes first and provides the speedup baseline.
        if label == "dense":
            dense_ms = result["ms"]

        speedup = dense_ms / result["ms"] if dense_ms is not None else float("nan")

        print(
            f"{label:>18s} | "
            f"{result['val']:8.4f} | "
            f"{result['ms']:8.2f} | "
            f"{speedup:8.3f}"
        )


if __name__ == "__main__":
    main()