Commit ·
bc1b8eb
1
Parent(s): 514b330
Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
Browse files
- e2e_bench.py → experiments/e2e_bench.py +0 -0
- e2e_full.py → experiments/e2e_full.py +0 -0
- experiments/pyproject.toml +19 -0
- experiments/sparse_linear_v10_metal/README.md +62 -0
- experiments/sparse_linear_v10_metal/setup.py +45 -0
- experiments/sparse_linear_v10_metal/sparse_linear.metal +65 -0
- experiments/sparse_linear_v10_metal/sparse_linear_ops.mm +158 -0
- experiments/sparse_linear_v10_metal/sparse_transformer_v10.py +1112 -0
- experiments/sparse_linear_v11_gather_vs_metal/README.md +62 -0
- experiments/sparse_linear_v11_gather_vs_metal/input.txt +0 -0
- experiments/sparse_linear_v11_gather_vs_metal/setup.py +15 -0
- experiments/sparse_linear_v11_gather_vs_metal/sparse_linear.metal +62 -0
- experiments/sparse_linear_v11_gather_vs_metal/sparse_linear_ops.mm +168 -0
- experiments/sparse_linear_v11_gather_vs_metal/sparse_transformer_v11.py +406 -0
- experiments/sparse_linear_v11_gather_vs_metal/sparse_transformer_v13.py +419 -0
- experiments/sparse_linear_v11_gather_vs_metal/tiny.py +441 -0
- experiments/sparse_transformer_v15_inactive_prediction.py +729 -0
- experiments/sparse_transformer_v16_sensor_scheduler.py +677 -0
- experiments/sparse_transformer_v17_radar_scheduler.py +725 -0
- experiments/sparse_transformer_v6.py +596 -0
- experiments/sparse_transformer_v7.py +780 -0
- experiments/sparse_transformer_v8.py +943 -0
- experiments/sparse_transformer_v9.py +1042 -0
- experiments/surprise_topk_gradient_prototype-v2.py +418 -0
- experiments/surprise_topk_gradient_prototype-v3.py +487 -0
- experiments/surprise_topk_gradient_prototype-v4.py +571 -0
- experiments/surprise_topk_gradient_prototype-v5.py +563 -0
- experiments/surprise_topk_gradient_prototype.py +426 -0
- triton_sparse.py → experiments/triton_sparse.py +0 -0
- triton_v2.py → experiments/triton_v2.py +0 -0
- experiments/uv.lock +0 -0
- paper/main.tex +307 -0
- sparse_transformer_v18_fast_knn.py +459 -0
- sparse_transformer_v18_fast_knn_triton.py +1044 -0
e2e_bench.py → experiments/e2e_bench.py
RENAMED
|
File without changes
|
e2e_full.py → experiments/e2e_full.py
RENAMED
|
File without changes
|
experiments/pyproject.toml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "surprise-topk-gradient"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Prototype for surprise Top-K gradient training experiments"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"numpy>=1.24",
|
| 12 |
+
"torch>=2.0",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
[project.scripts]
|
| 16 |
+
surprise-topk-gradient = "surprise_topk_gradient_prototype:main"
|
| 17 |
+
|
| 18 |
+
[tool.setuptools]
|
| 19 |
+
py-modules = ["surprise_topk_gradient_prototype"]
|
experiments/sparse_linear_v10_metal/README.md
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sparse Transformer v10: Metal-backed active-row Linear backward benchmark
|
| 2 |
+
|
| 3 |
+
This bundle is the first empirical optimization test rather than only a correctness prototype.
|
| 4 |
+
It compares dense Transformer training against sparse active-row backward variants and reports
|
| 5 |
+
validation loss plus wall-clock metrics (`step_ms`, `tokens_per_s`).
|
| 6 |
+
|
| 7 |
+
## Files
|
| 8 |
+
|
| 9 |
+
- `sparse_transformer_v10.py` — training + benchmark harness.
|
| 10 |
+
- `sparse_linear.metal` — Metal kernels for active-row `dW/db` and sparse `dX`.
|
| 11 |
+
- `sparse_linear_ops.mm` — PyTorch/MPS extension glue.
|
| 12 |
+
- `setup.py` — builds the extension and compiles the `.metallib`.
|
| 13 |
+
|
| 14 |
+
## Install the Metal extension
|
| 15 |
+
|
| 16 |
+
From this directory on macOS with PyTorch MPS available:
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
python3 -m pip install -e .
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
This uses `xcrun metal` and `xcrun metallib`, so Xcode command-line tools are required.
|
| 23 |
+
|
| 24 |
+
## Correctness/sanity run with PyTorch fallback
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
python3 sparse_transformer_v10.py \
|
| 28 |
+
--device mps \
|
| 29 |
+
--steps 2000 \
|
| 30 |
+
--active_fractions 0.05 0.02 \
|
| 31 |
+
--warmup_steps_list 5 \
|
| 32 |
+
--policies predicted_magnitude random \
|
| 33 |
+
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
|
| 34 |
+
--audit_every 0 \
|
| 35 |
+
--kernel_backend torch \
|
| 36 |
+
--benchmark_sync
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Metal benchmark run
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
python3 sparse_transformer_v10.py \
|
| 43 |
+
--device mps \
|
| 44 |
+
--steps 2000 \
|
| 45 |
+
--active_fractions 0.05 0.02 \
|
| 46 |
+
--warmup_steps_list 5 \
|
| 47 |
+
--policies predicted_magnitude random \
|
| 48 |
+
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
|
| 49 |
+
--audit_every 0 \
|
| 50 |
+
--kernel_backend metal \
|
| 51 |
+
--benchmark_sync
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Empirical pass/fail criteria
|
| 55 |
+
|
| 56 |
+
A genuine optimization would show:
|
| 57 |
+
|
| 58 |
+
1. `predicted_magnitude` validation loss close to the PyTorch fallback v9/v10 result.
|
| 59 |
+
2. `predicted_magnitude` much better than `random` at the same active fraction.
|
| 60 |
+
3. `--kernel_backend metal` has lower `step_ms` or higher `tokens_per_s` than `--kernel_backend torch` and ideally dense baseline.
|
| 61 |
+
|
| 62 |
+
The first Metal kernels are intentionally simple and fp32-only. They are meant to prove or disprove the acceleration path before investing in tiled half-precision kernels.
|
experiments/sparse_linear_v10_metal/setup.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import shutil
|
| 2 |
+
import subprocess
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import torch
|
| 5 |
+
from setuptools import setup
|
| 6 |
+
from torch.utils.cpp_extension import BuildExtension, CppExtension
|
| 7 |
+
|
| 8 |
+
ROOT = Path(__file__).parent.resolve()
|
| 9 |
+
# PyTorch dylibs (libc10, libtorch, …) are @rpath-linked; embed torch/lib so import works from any cwd.
|
| 10 |
+
_TORCH_LIB = Path(torch.__file__).resolve().parent / "lib"
|
| 11 |
+
METAL_SRC = ROOT / "sparse_linear.metal"
|
| 12 |
+
AIR = ROOT / "sparse_linear.air"
|
| 13 |
+
METALLIB = ROOT / "sparse_linear_ops.metallib"
|
| 14 |
+
|
| 15 |
+
class MetalBuildExt(BuildExtension):
    """build_ext that compiles the Metal shader library before the C++ extension.

    Requires Xcode command-line tools (`xcrun metal` / `xcrun metallib`); after
    the normal build it drops the .metallib next to every built extension .so
    so the runtime loader can find it via dladdr.
    """

    def run(self):
        if not shutil.which("xcrun"):
            raise RuntimeError("xcrun not found. Install Xcode command line tools.")
        # Compile .metal -> .air -> .metallib before building the C++ glue.
        for tool_args in (
            ["metal", "-c", str(METAL_SRC), "-o", str(AIR)],
            ["metallib", str(AIR), "-o", str(METALLIB)],
        ):
            subprocess.check_call(["xcrun", "-sdk", "macosx", *tool_args])
        super().run()
        # Ship the metallib alongside each built extension binary.
        for ext_binary in Path(self.build_lib).rglob("sparse_linear_metal*.so"):
            shutil.copy2(METALLIB, ext_binary.parent / METALLIB.name)
|
| 26 |
+
|
| 27 |
+
setup(
|
| 28 |
+
name="sparse_linear_metal",
|
| 29 |
+
version="0.1.0",
|
| 30 |
+
ext_modules=[
|
| 31 |
+
CppExtension(
|
| 32 |
+
name="sparse_linear_metal",
|
| 33 |
+
sources=["sparse_linear_ops.mm"],
|
| 34 |
+
extra_compile_args={"cxx": ["-std=c++17", "-ObjC++", "-fobjc-arc"]},
|
| 35 |
+
extra_link_args=[
|
| 36 |
+
"-framework",
|
| 37 |
+
"Metal",
|
| 38 |
+
"-framework",
|
| 39 |
+
"Foundation",
|
| 40 |
+
f"-Wl,-rpath,{_TORCH_LIB}",
|
| 41 |
+
],
|
| 42 |
+
)
|
| 43 |
+
],
|
| 44 |
+
cmdclass={"build_ext": MetalBuildExt},
|
| 45 |
+
)
|
experiments/sparse_linear_v10_metal/sparse_linear.metal
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_stdlib>
|
| 2 |
+
using namespace metal;
|
| 3 |
+
|
| 4 |
+
struct SparseLinearParams {
|
| 5 |
+
uint32_t N;
|
| 6 |
+
uint32_t In;
|
| 7 |
+
uint32_t Out;
|
| 8 |
+
uint32_t dummy; // alignment padding
|
| 9 |
+
};
|
| 10 |
+
|
| 11 |
+
kernel void sparse_linear_grad_w_float(
    device const float* x [[buffer(0)]],            // [N, In] forward input
    device const float* gy [[buffer(1)]],           // [N, Out] upstream gradient
    device const bool* active_mask [[buffer(2)]],   // [Out] which output rows get gradients
    device float* grad_w [[buffer(3)]],             // [Out, In], caller pre-zeroes
    device float* grad_b [[buffer(4)]],             // [Out], caller pre-zeroes
    constant SparseLinearParams& p [[buffer(5)]],
    uint2 tid [[thread_position_in_grid]])
{
    const uint col = tid.x;      // input-feature index
    const uint out_row = tid.y;  // output-row index

    if (out_row >= p.Out || col >= p.In) return;

    // Inactive rows cost almost nothing: their threads retire immediately on
    // the GPU, so no CPU-side synchronization is needed to skip them.
    if (!active_mask[out_row]) return;

    // grad_w[out_row, col] = sum_n gy[n, out_row] * x[n, col]
    float w_acc = 0.0f;
    for (uint n = 0; n < p.N; ++n) {
        w_acc += gy[n * p.Out + out_row] * x[n * p.In + col];
    }
    grad_w[out_row * p.In + col] = w_acc;

    // One thread per active row (col == 0) also reduces the bias gradient.
    if (col == 0) {
        float b_acc = 0.0f;
        for (uint n = 0; n < p.N; ++n) {
            b_acc += gy[n * p.Out + out_row];
        }
        grad_b[out_row] = b_acc;
    }
}
|
| 45 |
+
|
| 46 |
+
kernel void sparse_linear_grad_x_float(
    device const float* gy [[buffer(0)]],           // [N, Out] upstream gradient
    device const float* weight [[buffer(1)]],       // [Out, In]
    device const bool* active_mask [[buffer(2)]],   // [Out] rows included in dX
    device float* grad_x [[buffer(3)]],             // [N, In], caller pre-zeroes
    constant SparseLinearParams& p [[buffer(4)]],
    uint2 tid [[thread_position_in_grid]])
{
    const uint col = tid.x;     // input-feature index
    const uint sample = tid.y;  // batch-row index
    if (sample >= p.N || col >= p.In) return;

    // grad_x[sample, col] = sum over ACTIVE rows r of gy[sample, r] * W[r, col];
    // masked-out rows contribute exactly zero.
    float acc = 0.0f;
    for (uint r = 0; r < p.Out; ++r) {
        if (!active_mask[r]) continue;
        acc += gy[sample * p.Out + r] * weight[r * p.In + col];
    }
    grad_x[sample * p.In + col] = acc;
}
|
experiments/sparse_linear_v10_metal/sparse_linear_ops.mm
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include <ATen/ATen.h>
|
| 3 |
+
#include <ATen/mps/MPSStream.h>
|
| 4 |
+
#include <ATen/mps/MPSDevice.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <filesystem>
|
| 7 |
+
#include <dlfcn.h>
|
| 8 |
+
#import <Metal/Metal.h>
|
| 9 |
+
#import <Foundation/Foundation.h>
|
| 10 |
+
#include <mutex>
|
| 11 |
+
|
| 12 |
+
namespace fs = std::filesystem;
|
| 13 |
+
|
| 14 |
+
namespace {
|
| 15 |
+
struct SparseLinearParams {
|
| 16 |
+
uint32_t N;
|
| 17 |
+
uint32_t In;
|
| 18 |
+
uint32_t Out;
|
| 19 |
+
uint32_t dummy;
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
static id<MTLLibrary> g_lib = nil;
|
| 23 |
+
static id<MTLComputePipelineState> g_pipeline_grad_w = nil;
|
| 24 |
+
static id<MTLComputePipelineState> g_pipeline_grad_x = nil;
|
| 25 |
+
static std::mutex g_mutex;
|
| 26 |
+
|
| 27 |
+
static std::string metallib_path_for_this_module() {
|
| 28 |
+
Dl_info info;
|
| 29 |
+
if (dladdr((void*)&metallib_path_for_this_module, &info) == 0 || info.dli_fname == nullptr) return std::string();
|
| 30 |
+
fs::path so_path(info.dli_fname);
|
| 31 |
+
return (so_path.parent_path() / "sparse_linear_ops.metallib").string();
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
static void ensure_library_locked(id<MTLDevice> device) {
|
| 35 |
+
if (g_lib != nil) return;
|
| 36 |
+
std::string path = metallib_path_for_this_module();
|
| 37 |
+
TORCH_CHECK(!path.empty(), "sparse_linear_ops: failed to locate extension path via dladdr");
|
| 38 |
+
NSString* ns_path = [NSString stringWithUTF8String:path.c_str()];
|
| 39 |
+
NSURL* url = [NSURL fileURLWithPath:ns_path];
|
| 40 |
+
NSError* err = nil;
|
| 41 |
+
g_lib = [device newLibraryWithURL:url error:&err];
|
| 42 |
+
if (g_lib == nil) {
|
| 43 |
+
const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
|
| 44 |
+
TORCH_CHECK(false, "sparse_linear_ops: failed to load metallib at ", path, ": ", msg);
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
static id<MTLComputePipelineState> ensure_pipeline(id<MTLDevice> device, id<MTLComputePipelineState>* pipeline, const char* fn_name) {
|
| 49 |
+
std::lock_guard<std::mutex> lock(g_mutex);
|
| 50 |
+
ensure_library_locked(device);
|
| 51 |
+
if (*pipeline != nil) return *pipeline;
|
| 52 |
+
NSString* ns_fn = [NSString stringWithUTF8String:fn_name];
|
| 53 |
+
id<MTLFunction> fn =[g_lib newFunctionWithName:ns_fn];
|
| 54 |
+
TORCH_CHECK(fn != nil, "sparse_linear_ops: function `", fn_name, "` not found in metallib");
|
| 55 |
+
NSError* err = nil;
|
| 56 |
+
*pipeline = [device newComputePipelineStateWithFunction:fn error:&err];
|
| 57 |
+
if (*pipeline == nil) {
|
| 58 |
+
const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
|
| 59 |
+
TORCH_CHECK(false, "sparse_linear_ops: failed to create pipeline for ", fn_name, ": ", msg);
|
| 60 |
+
}
|
| 61 |
+
return *pipeline;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
static inline id<MTLBuffer> storage_as_mtlbuffer(const at::Tensor& t) {
|
| 65 |
+
void* ctx = t.storage().data_ptr().get_context();
|
| 66 |
+
TORCH_CHECK(ctx != nullptr, "sparse_linear_ops: expected MPS tensor storage with MTLBuffer context");
|
| 67 |
+
return (__bridge id<MTLBuffer>)ctx;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
static inline NSUInteger storage_offset_bytes(const at::Tensor& t) {
|
| 71 |
+
return (NSUInteger)(t.storage_offset() * (int64_t)t.element_size());
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
// Validate that `t` is an MPS-resident, contiguous float32 tensor.
// The kernels in this v10 extension are fp32-only and index raw storage
// linearly, so dtype and contiguity are hard requirements.
// Fix: the error message previously said "v12 kernel", which does not match
// this v10 extension and would mislead users debugging dtype failures.
static void check_mps_float_contig(const at::Tensor& t, const char* name) {
    TORCH_CHECK(t.device().is_mps(), name, " must be on MPS");
    TORCH_CHECK(t.dtype() == at::kFloat, name, " must be float32 for v10 kernel");
    TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
}
|
| 79 |
+
} // namespace
|
| 80 |
+
|
| 81 |
+
std::vector<at::Tensor> sparse_linear_grad_wb(at::Tensor x2d, at::Tensor gy2d, at::Tensor active_mask) {
|
| 82 |
+
check_mps_float_contig(x2d, "x2d");
|
| 83 |
+
check_mps_float_contig(gy2d, "gy2d");
|
| 84 |
+
TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
|
| 85 |
+
TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
|
| 86 |
+
TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");
|
| 87 |
+
|
| 88 |
+
int64_t N = x2d.size(0);
|
| 89 |
+
int64_t In = x2d.size(1);
|
| 90 |
+
int64_t Out = active_mask.size(0);
|
| 91 |
+
TORCH_CHECK(gy2d.size(1) == Out, "gy2d width must equal active_mask size");
|
| 92 |
+
|
| 93 |
+
auto grad_w = at::zeros({Out, In}, x2d.options());
|
| 94 |
+
auto grad_b = at::zeros({Out}, x2d.options());
|
| 95 |
+
|
| 96 |
+
id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
|
| 97 |
+
id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_w, "sparse_linear_grad_w_float");
|
| 98 |
+
at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
|
| 99 |
+
|
| 100 |
+
id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();[encoder setComputePipelineState:pipeline];
|
| 101 |
+
|
| 102 |
+
auto set_tensor = [&](const at::Tensor& t, int idx) {[encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
|
| 103 |
+
};
|
| 104 |
+
set_tensor(x2d, 0);
|
| 105 |
+
set_tensor(gy2d, 1);
|
| 106 |
+
set_tensor(active_mask, 2);
|
| 107 |
+
set_tensor(grad_w, 3);
|
| 108 |
+
set_tensor(grad_b, 4);
|
| 109 |
+
|
| 110 |
+
SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
|
| 111 |
+
[encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:5];
|
| 112 |
+
|
| 113 |
+
MTLSize tg = MTLSizeMake(16, 16, 1);
|
| 114 |
+
MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((Out + 15) / 16), 1);[encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
|
| 115 |
+
|
| 116 |
+
return {grad_w, grad_b};
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
at::Tensor sparse_linear_grad_x(at::Tensor gy2d, at::Tensor weight, at::Tensor active_mask) {
|
| 120 |
+
check_mps_float_contig(gy2d, "gy2d");
|
| 121 |
+
check_mps_float_contig(weight, "weight");
|
| 122 |
+
TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
|
| 123 |
+
TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
|
| 124 |
+
TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");
|
| 125 |
+
|
| 126 |
+
int64_t N = gy2d.size(0);
|
| 127 |
+
int64_t Out = gy2d.size(1);
|
| 128 |
+
int64_t In = weight.size(1);
|
| 129 |
+
TORCH_CHECK(weight.size(0) == Out, "weight out_features must match gy2d width");
|
| 130 |
+
|
| 131 |
+
auto grad_x = at::zeros({N, In}, gy2d.options());
|
| 132 |
+
|
| 133 |
+
id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
|
| 134 |
+
id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_x, "sparse_linear_grad_x_float");
|
| 135 |
+
at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
|
| 136 |
+
|
| 137 |
+
id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();[encoder setComputePipelineState:pipeline];
|
| 138 |
+
|
| 139 |
+
auto set_tensor = [&](const at::Tensor& t, int idx) {[encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
|
| 140 |
+
};
|
| 141 |
+
set_tensor(gy2d, 0);
|
| 142 |
+
set_tensor(weight, 1);
|
| 143 |
+
set_tensor(active_mask, 2);
|
| 144 |
+
set_tensor(grad_x, 3);
|
| 145 |
+
|
| 146 |
+
SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};[encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:4];
|
| 147 |
+
|
| 148 |
+
MTLSize tg = MTLSizeMake(16, 16, 1);
|
| 149 |
+
MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((N + 15) / 16), 1);
|
| 150 |
+
[encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
|
| 151 |
+
|
| 152 |
+
return grad_x;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 156 |
+
m.def("sparse_linear_grad_wb", &sparse_linear_grad_wb, "Sparse active-row Linear dW/db (Metal/MPS, fp32)");
|
| 157 |
+
m.def("sparse_linear_grad_x", &sparse_linear_grad_x, "Sparse active-row Linear dX (Metal/MPS, fp32)");
|
| 158 |
+
}
|
experiments/sparse_linear_v10_metal/sparse_transformer_v10.py
ADDED
|
@@ -0,0 +1,1112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v10: empirical sparse-backward optimization benchmark with optional Metal kernels.
|
| 3 |
+
|
| 4 |
+
v8 proved that the row-sparse mask can be moved into a custom Linear backward.
|
| 5 |
+
v10 keeps the no-audit training path and adds wall-clock benchmarking plus an optional Metal active-row Linear backward kernel.
|
| 6 |
+
|
| 7 |
+
Default behavior
|
| 8 |
+
----------------
|
| 9 |
+
1. Run a short dense warmup, usually 5 steps.
|
| 10 |
+
2. Initialize the EMA row-importance predictor from those dense warmup gradients.
|
| 11 |
+
3. After warmup, choose active rows from the predictor.
|
| 12 |
+
4. Train using sparse backward.
|
| 13 |
+
5. Update EMA statistics only from rows that were actually active/observed.
|
| 14 |
+
6. Do not compute dense gradients unless --audit_every > 0.
|
| 15 |
+
|
| 16 |
+
Audit behavior
|
| 17 |
+
--------------
|
| 18 |
+
--audit_every 0
|
| 19 |
+
No dense audit after warmup. Cosine/Jaccard/top20 are unavailable and show as nan.
|
| 20 |
+
|
| 21 |
+
--audit_every N
|
| 22 |
+
Every N steps, run an extra dense backward pass on the same batch only to
|
| 23 |
+
measure cosine/top20/Jaccard. The audit is NOT used to update the selector,
|
| 24 |
+
except for oracle_current, which is explicitly an upper-bound control.
|
| 25 |
+
|
| 26 |
+
This is still not a wall-clock benchmark on vanilla PyTorch/MPS/CPU. The custom
|
| 27 |
+
backward uses indexing and ordinary PyTorch matmuls. The goal is to verify that
|
| 28 |
+
the method survives without dense information after warmup.
|
| 29 |
+
|
| 30 |
+
Examples
|
| 31 |
+
--------
|
| 32 |
+
No-audit practical run:
|
| 33 |
+
python3 sparse_transformer_v9.py \
|
| 34 |
+
--device mps \
|
| 35 |
+
--steps 2000 \
|
| 36 |
+
--active_fractions 0.05 0.02 \
|
| 37 |
+
--warmup_steps_list 5 \
|
| 38 |
+
--policies predicted_magnitude random \
|
| 39 |
+
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
|
| 40 |
+
--audit_every 0
|
| 41 |
+
|
| 42 |
+
Occasional audit for measurement only:
|
| 43 |
+
python3 sparse_transformer_v9.py \
|
| 44 |
+
--steps 2000 \
|
| 45 |
+
--active_fractions 0.05 0.02 \
|
| 46 |
+
--warmup_steps_list 5 \
|
| 47 |
+
--policies predicted_magnitude random \
|
| 48 |
+
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
|
| 49 |
+
--audit_every 100
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
from __future__ import annotations
|
| 54 |
+
|
| 55 |
+
import argparse
|
| 56 |
+
import math
|
| 57 |
+
import random
|
| 58 |
+
import time
|
| 59 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 60 |
+
|
| 61 |
+
import torch
|
| 62 |
+
|
| 63 |
+
torch.set_num_threads(1)
|
| 64 |
+
import torch.nn as nn
|
| 65 |
+
import torch.nn.functional as F
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
import sparse_linear_metal # built by setup.py in this folder
|
| 69 |
+
except Exception:
|
| 70 |
+
sparse_linear_metal = None
|
| 71 |
+
|
| 72 |
+
USE_METAL_KERNEL = False
|
| 73 |
+
|
| 74 |
+
def sync_device(device: str) -> None:
    """Block until all queued kernels on *device* have finished.

    A no-op for CPU or when the requested backend is unavailable. Used to make
    wall-clock timings honest on asynchronous backends (CUDA/MPS).
    """
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
| 79 |
+
|
| 80 |
+
Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"]
|
| 81 |
+
BackwardMode = Literal["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# -----------------------------
|
| 85 |
+
# Reproducibility and device
|
| 86 |
+
# -----------------------------
|
| 87 |
+
|
| 88 |
+
def set_seed(seed: int) -> None:
    """Seed Python's and PyTorch's RNGs (including all CUDA devices) for reproducibility."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def default_device() -> str:
    """Return the best available backend name: "cuda", then "mps", then "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU-resident torch.Generator seeded with *seed*.

    ``Generator.manual_seed`` returns the generator itself, so the two steps
    collapse into one expression.
    """
    return torch.Generator(device="cpu").manual_seed(seed)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# -----------------------------
|
| 110 |
+
# Data
|
| 111 |
+
# -----------------------------
|
| 112 |
+
|
| 113 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
|
| 114 |
+
rng = random.Random(seed)
|
| 115 |
+
names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
|
| 116 |
+
verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
|
| 117 |
+
objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
|
| 118 |
+
adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
|
| 119 |
+
clauses = [
|
| 120 |
+
"when the loss falls",
|
| 121 |
+
"after the mask shifts",
|
| 122 |
+
"before the model answers",
|
| 123 |
+
"while the signal drifts",
|
| 124 |
+
"if the pattern repeats",
|
| 125 |
+
"because the tail is noisy",
|
| 126 |
+
]
|
| 127 |
+
symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]
|
| 128 |
+
|
| 129 |
+
lines: List[str] = []
|
| 130 |
+
for _ in range(n_sentences):
|
| 131 |
+
t = rng.randrange(6)
|
| 132 |
+
if t == 0:
|
| 133 |
+
line = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
|
| 134 |
+
elif t == 1:
|
| 135 |
+
line = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
|
| 136 |
+
elif t == 2:
|
| 137 |
+
a, b = rng.sample(symbols, 2)
|
| 138 |
+
line = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}."
|
| 139 |
+
elif t == 3:
|
| 140 |
+
line = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
|
| 141 |
+
elif t == 4:
|
| 142 |
+
seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
|
| 143 |
+
line = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}."
|
| 144 |
+
else:
|
| 145 |
+
line = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."
|
| 146 |
+
lines.append(line)
|
| 147 |
+
return "\n".join(lines) + "\n"
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class CharCorpus:
    """Character-level language-modeling dataset with a 90/10 train/val split."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        # Forward and inverse character/index maps.
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        boundary = int(0.9 * len(encoded))
        self.train_data = encoded[:boundary]
        self.val_data = encoded[boundary:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample (input, next-char target) windows from the chosen split.

        Raises ValueError when the split is too small to cut a full window.
        """
        source = self.train_data if split == "train" else self.val_data
        max_start = len(source) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        starts = torch.randint(max_start, (batch_size,), generator=generator)
        xs = [source[s : s + self.block_size] for s in starts]
        ys = [source[s + 1 : s + self.block_size + 1] for s in starts]
        return torch.stack(xs).to(self.device), torch.stack(ys).to(self.device)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def load_text(args: argparse.Namespace) -> str:
    """Return training text: the file at args.text_path if set, else synthetic."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as handle:
        return handle.read()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# -----------------------------
|
| 188 |
+
# Sparse Linear autograd
|
| 189 |
+
# -----------------------------
|
| 190 |
+
|
| 191 |
+
class MaskedLinearFunction(torch.autograd.Function):
    """Dense-forward linear whose backward only computes weight/bias gradients
    for the output rows flagged in ``active_rows``.

    The forward pass is a plain ``F.linear``; sparsity applies only to the
    backward pass. ``sparse_dx`` additionally restricts grad_x to use only the
    active rows of the weight (an approximation of the true grad_x).
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_rows: torch.Tensor,  # boolean mask over output rows (rows of `weight`)
        sparse_dx: bool,  # if True, grad_x uses only active weight rows
    ) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_rows)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        # Forward is fully dense; masking only changes what backward computes.
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):  # type: ignore[override]
        """Return (grad_x, grad_weight, grad_bias, None, None).

        grad_weight/grad_bias rows outside ``active_rows`` are left at zero.
        """
        x, weight, active_rows = ctx.saved_tensors
        sparse_dx = bool(ctx.sparse_dx)
        has_bias = bool(ctx.has_bias)

        # Flatten leading dims so both x and grad_y are 2-D (tokens x features).
        x_shape = x.shape
        x_flat = x.reshape(-1, x.shape[-1]).contiguous()
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1]).contiguous()

        # Indices of the active output rows.
        active_idx = torch.nonzero(active_rows, as_tuple=False).flatten().to(dtype=torch.long).contiguous()

        # The custom Metal kernel path requires float32 contiguous MPS tensors.
        can_use_metal = (
            USE_METAL_KERNEL
            and sparse_linear_metal is not None
            and x.device.type == "mps"
            and weight.device.type == "mps"
            and x.dtype == torch.float32
            and weight.dtype == torch.float32
            and gy_flat.dtype == torch.float32
            and x_flat.is_contiguous()
            and gy_flat.is_contiguous()
            and weight.is_contiguous()
        )

        if can_use_metal:
            # Fused kernel computes grad_weight (active rows only) and a full
            # bias-gradient vector in one call.
            grad_weight, grad_bias_full = sparse_linear_metal.sparse_linear_grad_wb(
                x_flat, gy_flat, active_idx, int(weight.shape[0])
            )
            grad_bias = grad_bias_full if has_bias else None
            if sparse_dx:
                grad_x_flat = sparse_linear_metal.sparse_linear_grad_x(gy_flat, weight.contiguous(), active_idx)
            else:
                # Full (exact) input gradient.
                grad_x_flat = gy_flat @ weight
            grad_x = grad_x_flat.reshape(x_shape)
            return grad_x, grad_weight, grad_bias, None, None

        # Portable PyTorch fallback. This is mathematically identical but usually
        # not an actual wall-clock optimization on MPS/CPU because indexing and
        # general matmul overhead dominate.
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if has_bias else None

        if active_idx.numel() > 0:
            # Only the active output columns of grad_y contribute.
            gy_active = gy_flat[:, active_idx]
            grad_weight[active_idx] = gy_active.transpose(0, 1) @ x_flat
            if grad_bias is not None:
                grad_bias[active_idx] = gy_active.sum(dim=0)

            if sparse_dx:
                # Approximate grad_x using only the active weight rows.
                grad_x_flat = gy_active @ weight[active_idx]
            else:
                grad_x_flat = gy_flat @ weight
        else:
            # No active rows: weight/bias grads stay zero.
            if sparse_dx:
                grad_x_flat = torch.zeros_like(x_flat)
            else:
                grad_x_flat = gy_flat @ weight

        grad_x = grad_x_flat.reshape(x_shape)
        return grad_x, grad_weight, grad_bias, None, None
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class SparseLinear(nn.Linear):
    """nn.Linear with an optional row-sparse backward pass."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        # Sparse backward stays off until set_sparse_backward() enables it.
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_rows: Optional[torch.Tensor] = None

    def set_sparse_backward(self, enabled: bool, active_rows: Optional[torch.Tensor], sparse_dx: bool) -> None:
        """Toggle row-sparse backward and install the active-row mask."""
        self.sparse_enabled = bool(enabled)
        self.sparse_dx = bool(sparse_dx)
        self.active_rows = active_rows

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Route through the custom Function only when a mask is installed;
        # otherwise behave exactly like a standard nn.Linear.
        if self.sparse_enabled and self.active_rows is not None:
            return MaskedLinearFunction.apply(x, self.weight, self.bias, self.active_rows, self.sparse_dx)
        return F.linear(x, self.weight, self.bias)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# -----------------------------
|
| 290 |
+
# Mini GPT
|
| 291 |
+
# -----------------------------
|
| 292 |
+
|
| 293 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on SparseLinear projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused qkv projection and output projection.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Forbid attention to future positions.
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: 4x expansion, GELU, projection, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class MiniGPT(nn.Module):
    """Small GPT-style decoder using SparseLinear for all linear projections."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*(Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)))
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = SparseLinear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are supplied."""
        B, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.drop(self.tok_emb(idx) + self.pos_emb(positions)[None, :, :])
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def named_sparse_linear_modules(model: nn.Module) -> List[Tuple[str, SparseLinear]]:
    """Collect every (name, module) pair in *model* that is a SparseLinear."""
    found: List[Tuple[str, SparseLinear]] = []
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            found.append((name, module))
    return found
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    """Return (total params, params owned by SparseLinear layers, their ratio)."""
    total = sum(p.numel() for p in model.parameters())
    linear = sum(
        m.weight.numel() + (m.bias.numel() if m.bias is not None else 0)
        for _, m in named_sparse_linear_modules(model)
    )
    # max(1, total) guards against division by zero for an empty model.
    return total, linear, linear / max(1, total)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def configure_sparse_linears(
    model: nn.Module,
    masker: Optional["RowMasker"],
    enabled: bool,
    backward_mode: Optional[str],
) -> None:
    """Push the current row masks (or disable sparsity) into every SparseLinear."""
    wants_sparse_dx = backward_mode == "sparse_dW_sparse_dX"
    for _, module in named_sparse_linear_modules(model):
        mask = None if masker is None else masker.row_mask_for(module)
        module.set_sparse_backward(enabled=enabled, active_rows=mask, sparse_dx=wants_sparse_dx)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# -----------------------------
|
| 397 |
+
# Mask selector
|
| 398 |
+
# -----------------------------
|
| 399 |
+
|
| 400 |
+
class RowMasker:
    """Tracks per-output-row gradient-mass statistics across all SparseLinear
    layers and selects which rows are "active" for sparse backward passes.

    Rows are indexed globally: each SparseLinear's output rows get a contiguous
    id range, concatenated in module order. Fixes vs. the previous revision:
    removed a duplicated ``@torch.no_grad()`` decorator on
    ``update_predictor_from_observed_mass`` and a dead local list in
    ``__init__``; behavior is unchanged.
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        ucb_alpha: float,
        mass_init: float,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.ucb_alpha = ucb_alpha
        self.mass_init = mass_init
        self.device = device
        self.step_index = 0

        # Assign each SparseLinear a contiguous global id range over its rows.
        self.linear_modules = [m for _, m in named_sparse_linear_modules(model)]
        self.module_to_ids: Dict[SparseLinear, torch.Tensor] = {}
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            self.module_to_ids[m] = torch.arange(offset, offset + n, device=device)
            offset += n
        self.n_blocks = offset

        # Per-row EMA of gradient mass, last dense-audit mass, observation counts,
        # and a global EMA used to scale the UCB exploration bonus.
        self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.observed_count = torch.zeros(self.n_blocks, device=device)
        self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)

        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[SparseLinear, torch.Tensor] = {
            m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
        }

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask selecting the top `fraction` of `values` (at least one)."""
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        # Tiny noise breaks ties so equal values are not selected deterministically.
        noisy = values + 1e-9 * torch.rand_like(values)
        mask[torch.topk(noisy, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard similarity of two boolean masks (union clamped to avoid /0)."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        """Install a new global active mask and refresh the per-module views."""
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]

    def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
        """Pick top-scoring rows plus a random exploration quota."""
        n = self.n_blocks
        k_total = max(1, int(self.active_fraction * n))
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(n, dtype=torch.bool, device=self.device)

        if k_exploit > 0:
            active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
        if k_explore > 0:
            # Sample exploration rows uniformly from the not-yet-selected rows.
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
            active[pick] = True
        return active

    def choose_pre_backward(self, step: int) -> None:
        """Select the active-row set for this step, before the backward pass."""
        self.step_index = step
        if step < self.warmup_steps:
            # Dense bootstrap: everything active.
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "oracle_current":
            # Oracle cannot choose until the dense audit gradient is known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "random":
            self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
            return

        if self.policy == "stale_current":
            # Select from the most recent dense audit mass.
            self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
            return

        if self.policy == "predicted_magnitude":
            self._set_active(self._sample_exploit_explore(self.predicted_mass))
            return

        if self.policy == "ucb_magnitude":
            # Standard UCB bonus: rows observed less often get a larger boost.
            t = max(1, step - self.warmup_steps + 1)
            log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
            bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
            bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
            self._set_active(self._sample_exploit_explore(self.predicted_mass + bonus))
            return

        raise ValueError(f"Unknown policy: {self.policy}")

    @torch.no_grad()
    def current_gradient_mass_from_grads(self) -> torch.Tensor:
        """Per-row gradient L2 norm (weight row plus bias) from current .grad."""
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def update_predictor_from_observed_mass(self, mass: torch.Tensor, observed: Optional[torch.Tensor] = None) -> Dict[str, float]:
        """Update EMA statistics only for observed rows.

        After warmup, sparse backward only gives trustworthy gradients for active
        rows, so only those rows are allowed to update predicted_mass.
        """
        if observed is None:
            observed = self.active

        new_active = observed & (self.observed_count == 0)
        # Decay every prediction; observed rows are refreshed below.
        self.predicted_mass.mul_(self.unobserved_decay)

        if bool(observed.any().item()):
            obs_mass = mass[observed]
            # First observation overwrites the prior; later ones blend via EMA.
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()

        stability = self._jaccard(self.active, self.prev_active)
        self.prev_active = self.active.clone()

        return {
            "stability": stability,
            "active_fraction_real": float(self.active.float().mean().item()),
            "coverage": float((self.observed_count > 0).float().mean().item()),
            "avg_obs_count": float(self.observed_count.mean().item()),
            "new_active_fraction": float(new_active.float().mean().item()),
        }

    @torch.no_grad()
    def audit_metrics_from_mass(self, mass: torch.Tensor) -> Dict[str, float]:
        """Compute dense-audit metrics without updating the practical selector."""
        active = self.active
        # Ratio of captured gradient norm to the true norm.
        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)

        # How concentrated the gradient mass is in the top 20% of rows.
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        return {
            "cosine": cosine,
            "norm_ratio": cosine,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
        }

    def audit_and_update_from_mass(self, step: int, mass: torch.Tensor) -> Dict[str, float]:
        """Dense-audit path: choose rows (warmup/oracle), score, and update EMAs."""
        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        new_active = active & (self.observed_count == 0)

        # Practical rule: update predicted statistics only for active/observed rows.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()

        # Dense audit signal; only stale_current is allowed to use this for selection.
        self.last_full_mass = mass.detach().clone()

        return {
            "cosine": cosine,
            "norm_ratio": cosine,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
            "coverage": float((self.observed_count > 0).float().mean().item()),
            "avg_obs_count": float(self.observed_count.mean().item()),
            "new_active_fraction": float(new_active.float().mean().item()),
        }

    def row_mask_for(self, module: SparseLinear) -> Optional[torch.Tensor]:
        """Active-row mask for one SparseLinear, or None if unknown."""
        return self.row_masks.get(module)
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# -----------------------------
|
| 639 |
+
# Masked Adam
|
| 640 |
+
# -----------------------------
|
| 641 |
+
|
| 642 |
+
class MaskedAdam:
    """Hand-rolled Adam-style optimizer that can restrict SparseLinear updates
    to the rows marked active by a RowMasker.

    NOTE(review): no bias-correction terms appear in the update below, so this
    is Adam without bias correction; moments for inactive rows are frozen
    (not decayed) while their rows are masked.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        # When True and a masker is installed, parameters that are not part of
        # a SparseLinear (embeddings, layernorms, ...) are skipped entirely.
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        # Lazily-created per-parameter first/second moment buffers.
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Map each SparseLinear weight/bias parameter back to (module, kind).
        self.linear_param: Dict[nn.Parameter, Tuple[SparseLinear, str]] = {}
        for _, m in named_sparse_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        """Drop all gradients (set .grad to None rather than zeroing)."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one (possibly row-masked) Adam update to all gradients."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                continue

            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                # L2-style decay folded into the gradient (not decoupled AdamW).
                g = g.add(p, alpha=self.weight_decay)

            # Resolve the row mask for SparseLinear params; weight masks are
            # broadcast over the trailing (input-feature) dimension.
            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                # Dense update path (no bias correction).
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                # MPS can mis-handle expanded boolean masks for row-wise assignment
                # (e.g. reporting nonsense out-of-bounds indices). Use explicit
                # row indices and index_copy_ instead. This also avoids materializing
                # a full expanded mask for weight matrices.
                active_rows = row_mask.reshape(-1).nonzero(as_tuple=False).flatten()
                if active_rows.numel() == 0:
                    continue

                m_rows = m.index_select(0, active_rows)
                v_rows = v.index_select(0, active_rows)
                g_rows = g.index_select(0, active_rows)

                # Same Adam update as above, restricted to active rows.
                new_m_rows = self.beta1 * m_rows + (1.0 - self.beta1) * g_rows
                new_v_rows = self.beta2 * v_rows + (1.0 - self.beta2) * g_rows * g_rows
                update_rows = new_m_rows / (torch.sqrt(new_v_rows) + self.eps)
                p_rows = p.index_select(0, active_rows) - self.lr * update_rows

                m.index_copy_(0, active_rows, new_m_rows)
                v.index_copy_(0, active_rows, new_v_rows)
                p.index_copy_(0, active_rows, p_rows)
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
# -----------------------------
|
| 722 |
+
# Training utilities
|
| 723 |
+
# -----------------------------
|
| 724 |
+
|
| 725 |
+
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int, seed: int) -> Dict[str, float]:
    """Mean train/val loss over eval_iters batches drawn with fixed seeds.

    Sparse backward is disabled for evaluation and the model is returned to
    train mode afterwards.
    """
    model.eval()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    results: Dict[str, float] = {}
    for split in ("train", "val"):
        # Offset the seed so train and val draw different batch sequences.
        gen = make_cpu_generator(seed + (0 if split == "train" else 100000))
        total = 0.0
        for _ in range(eval_iters):
            xb, yb = corpus.get_batch(split, batch_size, generator=gen)
            _, loss = model(xb, yb)
            total += float(loss.item())
        results[split] = total / eval_iters
    model.train()
    return results
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def dense_audit_pass(model: nn.Module, corpus_batch: Tuple[torch.Tensor, torch.Tensor], opt: MaskedAdam, masker: RowMasker) -> torch.Tensor:
    """Run one fully dense forward/backward purely to measure per-row gradient mass.

    Gradients are cleared before and after, so no optimizer state is affected;
    only the returned mass tensor carries information out.
    """
    inputs, targets = corpus_batch
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    opt.zero_grad()
    _, audit_loss = model(inputs, targets)
    audit_loss.backward()
    mass = masker.current_gradient_mass_from_grads()
    opt.zero_grad()
    return mass
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def sparse_training_backward(
    model: nn.Module,
    corpus_batch: Tuple[torch.Tensor, torch.Tensor],
    opt: MaskedAdam,
    masker: Optional[RowMasker],
    backward_mode: Optional[BackwardMode],
) -> float:
    """One forward/backward pass, optionally with row-sparse backward.

    Returns the scalar loss. Sparse backward inside the linears is used only
    for the dW-sparse modes; "masked_optimizer" keeps backward dense and lets
    the optimizer do the masking.
    """
    inputs, targets = corpus_batch
    opt.zero_grad()

    use_sparse_backward = (
        masker is not None and backward_mode is not None and backward_mode != "masked_optimizer"
    )
    if use_sparse_backward:
        configure_sparse_linears(model, masker=masker, enabled=True, backward_mode=backward_mode)
    else:
        configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)

    _, loss = model(inputs, targets)
    loss.backward()
    # Always restore dense behavior so later forward passes are unaffected.
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    return float(loss.item())
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def train_run(
|
| 775 |
+
corpus: CharCorpus,
|
| 776 |
+
args: argparse.Namespace,
|
| 777 |
+
policy: Optional[Policy],
|
| 778 |
+
backward_mode: Optional[BackwardMode],
|
| 779 |
+
active_fraction: float,
|
| 780 |
+
warmup_steps: int,
|
| 781 |
+
explore_fraction: float,
|
| 782 |
+
seed_offset: int,
|
| 783 |
+
) -> Dict[str, float | str]:
|
| 784 |
+
# Same model initialization and same minibatch sequence for every run by default.
|
| 785 |
+
set_seed(args.seed + (seed_offset if args.unpaired_seeds else 0))
|
| 786 |
+
data_gen = make_cpu_generator(args.seed + 12345)
|
| 787 |
+
|
| 788 |
+
dev = corpus.device
|
| 789 |
+
model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
|
| 790 |
+
|
| 791 |
+
masker = None
|
| 792 |
+
if policy is not None:
|
| 793 |
+
masker = RowMasker(
|
| 794 |
+
model=model,
|
| 795 |
+
policy=policy,
|
| 796 |
+
active_fraction=active_fraction,
|
| 797 |
+
explore_fraction=explore_fraction,
|
| 798 |
+
mass_beta=args.mass_beta,
|
| 799 |
+
unobserved_decay=args.unobserved_decay,
|
| 800 |
+
warmup_steps=warmup_steps,
|
| 801 |
+
ucb_alpha=args.ucb_alpha,
|
| 802 |
+
mass_init=args.mass_init,
|
| 803 |
+
device=dev,
|
| 804 |
+
)
|
| 805 |
+
opt = MaskedAdam(
|
| 806 |
+
model,
|
| 807 |
+
masker,
|
| 808 |
+
lr=args.lr,
|
| 809 |
+
weight_decay=args.weight_decay,
|
| 810 |
+
freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
sums = {
|
| 814 |
+
"cosine": 0.0,
|
| 815 |
+
"norm_ratio": 0.0,
|
| 816 |
+
"top20_mass": 0.0,
|
| 817 |
+
"jacc_oracle": 0.0,
|
| 818 |
+
"stability": 0.0,
|
| 819 |
+
"active_fraction_real": 0.0,
|
| 820 |
+
"coverage": 0.0,
|
| 821 |
+
"avg_obs_count": 0.0,
|
| 822 |
+
"new_active_fraction": 0.0,
|
| 823 |
+
}
|
| 824 |
+
counts = {k: 0 for k in sums}
|
| 825 |
+
|
| 826 |
+
def add_metrics(metrics: Dict[str, float]) -> None:
|
| 827 |
+
for k, v in metrics.items():
|
| 828 |
+
if k in sums:
|
| 829 |
+
sums[k] += float(v)
|
| 830 |
+
counts[k] += 1
|
| 831 |
+
|
| 832 |
+
if args.benchmark_sync:
|
| 833 |
+
sync_device(dev)
|
| 834 |
+
t0 = time.perf_counter()
|
| 835 |
+
|
| 836 |
+
for step in range(args.steps):
|
| 837 |
+
batch = corpus.get_batch("train", args.batch_size, generator=data_gen)
|
| 838 |
+
|
| 839 |
+
if masker is None:
|
| 840 |
+
loss_value = sparse_training_backward(model, batch, opt, masker=None, backward_mode=None)
|
| 841 |
+
opt.step()
|
| 842 |
+
else:
|
| 843 |
+
if step < warmup_steps:
|
| 844 |
+
# Dense bootstrap. Every row is active and every row updates the predictor.
|
| 845 |
+
masker._set_active(torch.ones(masker.n_blocks, dtype=torch.bool, device=dev))
|
| 846 |
+
loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode="masked_optimizer")
|
| 847 |
+
full_mass = masker.current_gradient_mass_from_grads()
|
| 848 |
+
masker.last_full_mass = full_mass.detach().clone()
|
| 849 |
+
add_metrics(masker.audit_metrics_from_mass(full_mass))
|
| 850 |
+
add_metrics(masker.update_predictor_from_observed_mass(full_mass, observed=masker.active))
|
| 851 |
+
opt.step()
|
| 852 |
+
else:
|
| 853 |
+
masker.choose_pre_backward(step)
|
| 854 |
+
|
| 855 |
+
if policy == "oracle_current":
|
| 856 |
+
# Explicit upper bound. Oracle necessarily computes dense gradients to choose rows.
|
| 857 |
+
full_mass = dense_audit_pass(model, batch, opt, masker)
|
| 858 |
+
masker._set_active(masker._topk_mask(full_mass, active_fraction))
|
| 859 |
+
masker.last_full_mass = full_mass.detach().clone()
|
| 860 |
+
add_metrics(masker.audit_metrics_from_mass(full_mass))
|
| 861 |
+
elif args.audit_every > 0 and ((step - warmup_steps) % args.audit_every == 0):
|
| 862 |
+
# Measurement only. Do not update predicted_magnitude/ucb/random with this dense mass.
|
| 863 |
+
full_mass = dense_audit_pass(model, batch, opt, masker)
|
| 864 |
+
add_metrics(masker.audit_metrics_from_mass(full_mass))
|
| 865 |
+
if policy == "stale_current":
|
| 866 |
+
masker.last_full_mass = full_mass.detach().clone()
|
| 867 |
+
|
| 868 |
+
loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode=backward_mode)
|
| 869 |
+
|
| 870 |
+
# Practical selector update: only active rows were observed by the training backward pass.
|
| 871 |
+
observed_mass = masker.current_gradient_mass_from_grads()
|
| 872 |
+
add_metrics(masker.update_predictor_from_observed_mass(observed_mass, observed=masker.active))
|
| 873 |
+
opt.step()
|
| 874 |
+
|
| 875 |
+
if args.benchmark_sync:
|
| 876 |
+
sync_device(dev)
|
| 877 |
+
|
| 878 |
+
if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
|
| 879 |
+
losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 555)
|
| 880 |
+
name = "dense" if policy is None else f"{policy}/{backward_mode}"
|
| 881 |
+
print(
|
| 882 |
+
f"{name:38s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
|
| 883 |
+
f"loss={loss_value:.4f} train={losses['train']:.4f} val={losses['val']:.4f}"
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
if args.benchmark_sync:
|
| 887 |
+
sync_device(dev)
|
| 888 |
+
elapsed_s = time.perf_counter() - t0
|
| 889 |
+
train_tokens = float(args.steps * args.batch_size * args.block_size)
|
| 890 |
+
step_ms = 1000.0 * elapsed_s / max(1, args.steps)
|
| 891 |
+
tokens_per_s = train_tokens / max(elapsed_s, 1e-9)
|
| 892 |
+
|
| 893 |
+
losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 999)
|
| 894 |
+
row: Dict[str, float | str] = {
|
| 895 |
+
"run": "dense_baseline" if policy is None else policy,
|
| 896 |
+
"mode": "dense" if backward_mode is None else backward_mode,
|
| 897 |
+
"target_active": 1.0 if policy is None else active_fraction,
|
| 898 |
+
"warmup": warmup_steps,
|
| 899 |
+
"explore": explore_fraction if policy is not None else 0.0,
|
| 900 |
+
"train_loss": losses["train"],
|
| 901 |
+
"val_loss": losses["val"],
|
| 902 |
+
"elapsed_s": elapsed_s,
|
| 903 |
+
"step_ms": step_ms,
|
| 904 |
+
"tokens_per_s": tokens_per_s,
|
| 905 |
+
}
|
| 906 |
+
if masker is None:
|
| 907 |
+
row.update({
|
| 908 |
+
"cosine": float("nan"),
|
| 909 |
+
"norm_ratio": float("nan"),
|
| 910 |
+
"top20_mass": float("nan"),
|
| 911 |
+
"jacc_oracle": float("nan"),
|
| 912 |
+
"stability": float("nan"),
|
| 913 |
+
"active_fraction_real": 1.0,
|
| 914 |
+
"coverage": float("nan"),
|
| 915 |
+
"avg_obs_count": float("nan"),
|
| 916 |
+
"new_active_fraction": float("nan"),
|
| 917 |
+
})
|
| 918 |
+
else:
|
| 919 |
+
for k in sums:
|
| 920 |
+
row[k] = (sums[k] / counts[k]) if counts[k] > 0 else float("nan")
|
| 921 |
+
return row
|
| 922 |
+
|
| 923 |
+
def print_summary(rows: List[Dict[str, float | str]]) -> None:
|
| 924 |
+
print("\nSummary")
|
| 925 |
+
header = (
|
| 926 |
+
f"{'run':>22s} {'mode':>19s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
|
| 927 |
+
f"{'val':>8s} {'train':>8s} {'ms':>8s} {'tok/s':>9s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
|
| 928 |
+
f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
|
| 929 |
+
)
|
| 930 |
+
print(header)
|
| 931 |
+
print("-" * len(header))
|
| 932 |
+
for r in rows:
|
| 933 |
+
print(
|
| 934 |
+
f"{str(r['run']):>22s} "
|
| 935 |
+
f"{str(r['mode']):>19s} "
|
| 936 |
+
f"{float(r['target_active']):7.3f} "
|
| 937 |
+
f"{float(r['active_fraction_real']):7.3f} "
|
| 938 |
+
f"{int(float(r['warmup'])):5d} "
|
| 939 |
+
f"{float(r['explore']):5.2f} "
|
| 940 |
+
f"{float(r['val_loss']):8.4f} "
|
| 941 |
+
f"{float(r['train_loss']):8.4f} "
|
| 942 |
+
f"{float(r.get('step_ms', float('nan'))):8.2f} "
|
| 943 |
+
f"{float(r.get('tokens_per_s', float('nan'))):9.0f} "
|
| 944 |
+
f"{float(r['cosine']):7.3f} "
|
| 945 |
+
f"{float(r['top20_mass']):7.3f} "
|
| 946 |
+
f"{float(r['jacc_oracle']):7.3f} "
|
| 947 |
+
f"{float(r['stability']):7.3f} "
|
| 948 |
+
f"{float(r['coverage']):7.3f} "
|
| 949 |
+
f"{float(r['new_active_fraction']):7.3f}"
|
| 950 |
+
)
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
def parse_args() -> argparse.Namespace:
|
| 954 |
+
p = argparse.ArgumentParser()
|
| 955 |
+
p.add_argument("--text_path", type=str, default=None)
|
| 956 |
+
p.add_argument("--synthetic_sentences", type=int, default=12000)
|
| 957 |
+
p.add_argument("--steps", type=int, default=1000)
|
| 958 |
+
p.add_argument("--quick", action="store_true")
|
| 959 |
+
p.add_argument("--batch_size", type=int, default=32)
|
| 960 |
+
p.add_argument("--block_size", type=int, default=64)
|
| 961 |
+
p.add_argument("--n_layer", type=int, default=2)
|
| 962 |
+
p.add_argument("--n_head", type=int, default=4)
|
| 963 |
+
p.add_argument("--n_embd", type=int, default=64)
|
| 964 |
+
p.add_argument("--dropout", type=float, default=0.0)
|
| 965 |
+
p.add_argument("--lr", type=float, default=3e-4)
|
| 966 |
+
p.add_argument("--weight_decay", type=float, default=0.0)
|
| 967 |
+
p.add_argument("--active_fractions", type=float, nargs="+", default=[0.05, 0.02])
|
| 968 |
+
p.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "random"])
|
| 969 |
+
p.add_argument(
|
| 970 |
+
"--backward_modes",
|
| 971 |
+
type=str,
|
| 972 |
+
nargs="+",
|
| 973 |
+
default=["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"],
|
| 974 |
+
)
|
| 975 |
+
p.add_argument("--explore_fractions", type=float, nargs="+", default=[0.0])
|
| 976 |
+
p.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5])
|
| 977 |
+
p.add_argument("--mass_beta", type=float, default=0.95)
|
| 978 |
+
p.add_argument("--unobserved_decay", type=float, default=1.0)
|
| 979 |
+
p.add_argument("--mass_init", type=float, default=0.0)
|
| 980 |
+
p.add_argument("--ucb_alpha", type=float, default=1.0)
|
| 981 |
+
p.add_argument("--freeze_non_linear_when_sparse", action="store_true")
|
| 982 |
+
p.add_argument("--eval_interval", type=int, default=200)
|
| 983 |
+
p.add_argument("--eval_iters", type=int, default=20)
|
| 984 |
+
p.add_argument("--seed", type=int, default=7)
|
| 985 |
+
p.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda", "mps"])
|
| 986 |
+
p.add_argument("--audit_every", type=int, default=0, help="Dense audit interval after warmup. 0 disables audits except oracle_current.")
|
| 987 |
+
p.add_argument("--kernel_backend", type=str, default="torch", choices=["torch", "metal"], help="Use PyTorch fallback or the optional Metal active-row Linear backward kernel.")
|
| 988 |
+
p.add_argument("--benchmark_sync", action="store_true", help="Synchronize after every train step for fair wall-clock benchmarking on async devices.")
|
| 989 |
+
p.add_argument("--unpaired_seeds", action="store_true", help="Use different init seeds per run instead of paired seeds.")
|
| 990 |
+
p.add_argument("--verbose", action="store_true")
|
| 991 |
+
return p.parse_args()
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
def main() -> None:
|
| 995 |
+
global USE_METAL_KERNEL
|
| 996 |
+
args = parse_args()
|
| 997 |
+
USE_METAL_KERNEL = args.kernel_backend == "metal"
|
| 998 |
+
if USE_METAL_KERNEL and sparse_linear_metal is None:
|
| 999 |
+
raise RuntimeError(
|
| 1000 |
+
"--kernel_backend metal requested, but sparse_linear_metal is not importable. "
|
| 1001 |
+
"Build in the repo venv (PyTorch must be importable during setup): "
|
| 1002 |
+
"cd sparse_linear_v10_metal && python3 -m pip install -e . --no-build-isolation"
|
| 1003 |
+
)
|
| 1004 |
+
if args.quick:
|
| 1005 |
+
args.steps = 40
|
| 1006 |
+
args.eval_iters = 2
|
| 1007 |
+
args.batch_size = 8
|
| 1008 |
+
args.block_size = 32
|
| 1009 |
+
args.n_layer = 1
|
| 1010 |
+
args.n_embd = 32
|
| 1011 |
+
args.n_head = 4
|
| 1012 |
+
args.synthetic_sentences = 1200
|
| 1013 |
+
args.active_fractions = [0.05]
|
| 1014 |
+
args.policies = ["predicted_magnitude", "random"]
|
| 1015 |
+
args.backward_modes = ["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
|
| 1016 |
+
args.explore_fractions = [0.0]
|
| 1017 |
+
args.warmup_steps_list = [5]
|
| 1018 |
+
args.audit_every = 10
|
| 1019 |
+
|
| 1020 |
+
valid_policies = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
|
| 1021 |
+
valid_modes = {"masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"}
|
| 1022 |
+
for pol in args.policies:
|
| 1023 |
+
if pol not in valid_policies:
|
| 1024 |
+
raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid_policies)}")
|
| 1025 |
+
for mode in args.backward_modes:
|
| 1026 |
+
if mode not in valid_modes:
|
| 1027 |
+
raise ValueError(f"Unknown backward mode {mode!r}. Valid modes: {sorted(valid_modes)}")
|
| 1028 |
+
|
| 1029 |
+
set_seed(args.seed)
|
| 1030 |
+
dev = args.device if args.device != "auto" else default_device()
|
| 1031 |
+
print(f"device={dev}")
|
| 1032 |
+
corpus = CharCorpus(load_text(args), args.block_size, dev)
|
| 1033 |
+
print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
|
| 1034 |
+
print(f"policies={args.policies}")
|
| 1035 |
+
print(f"backward_modes={args.backward_modes}")
|
| 1036 |
+
print(f"active_fractions={args.active_fractions}")
|
| 1037 |
+
print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
|
| 1038 |
+
print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")
|
| 1039 |
+
print(f"paired_seeds={not args.unpaired_seeds}")
|
| 1040 |
+
print(f"kernel_backend={args.kernel_backend} benchmark_sync={args.benchmark_sync} metal_available={sparse_linear_metal is not None}")
|
| 1041 |
+
print(f"audit_every={args.audit_every} (0 means no dense audit after warmup, except oracle_current)")
|
| 1042 |
+
|
| 1043 |
+
tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
|
| 1044 |
+
total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
|
| 1045 |
+
del tmp_model
|
| 1046 |
+
print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
|
| 1047 |
+
if args.freeze_non_linear_when_sparse:
|
| 1048 |
+
print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
|
| 1049 |
+
else:
|
| 1050 |
+
print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")
|
| 1051 |
+
|
| 1052 |
+
if args.dropout != 0.0:
|
| 1053 |
+
print("warning: dropout is nonzero; dense audit and sparse training passes may see different dropout masks")
|
| 1054 |
+
|
| 1055 |
+
rows: List[Dict[str, float | str]] = []
|
| 1056 |
+
print("\nRunning dense baseline")
|
| 1057 |
+
rows.append(
|
| 1058 |
+
train_run(
|
| 1059 |
+
corpus,
|
| 1060 |
+
args,
|
| 1061 |
+
policy=None,
|
| 1062 |
+
backward_mode=None,
|
| 1063 |
+
active_fraction=1.0,
|
| 1064 |
+
warmup_steps=0,
|
| 1065 |
+
explore_fraction=0.0,
|
| 1066 |
+
seed_offset=0,
|
| 1067 |
+
)
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
seed_offset = 100
|
| 1071 |
+
for mode in args.backward_modes:
|
| 1072 |
+
for af in args.active_fractions:
|
| 1073 |
+
for pol in args.policies:
|
| 1074 |
+
explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
|
| 1075 |
+
for warmup in args.warmup_steps_list:
|
| 1076 |
+
for explore in explore_values:
|
| 1077 |
+
print(
|
| 1078 |
+
f"\nRunning mode={mode}, policy={pol}, "
|
| 1079 |
+
f"active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}"
|
| 1080 |
+
)
|
| 1081 |
+
rows.append(
|
| 1082 |
+
train_run(
|
| 1083 |
+
corpus,
|
| 1084 |
+
args,
|
| 1085 |
+
policy=pol, # type: ignore[arg-type]
|
| 1086 |
+
backward_mode=mode, # type: ignore[arg-type]
|
| 1087 |
+
active_fraction=af,
|
| 1088 |
+
warmup_steps=warmup,
|
| 1089 |
+
explore_fraction=explore,
|
| 1090 |
+
seed_offset=seed_offset,
|
| 1091 |
+
)
|
| 1092 |
+
)
|
| 1093 |
+
seed_offset += 1
|
| 1094 |
+
|
| 1095 |
+
print_summary(rows)
|
| 1096 |
+
|
| 1097 |
+
print("\nNotes")
|
| 1098 |
+
print(" masked_optimizer is the v7-style dense-backward simulation control.")
|
| 1099 |
+
print(" sparse_dW_full_dX uses custom Linear backward: sparse weight/bias grads, full input gradient.")
|
| 1100 |
+
print(" sparse_dW_sparse_dX uses custom Linear backward: sparse weight/bias grads and sparse input gradient.")
|
| 1101 |
+
print(" oracle_current uses dense audit gradients to choose rows; it is an upper bound.")
|
| 1102 |
+
print(" predicted_magnitude uses EMA mass from active/observed rows only.")
|
| 1103 |
+
print(" random is the sparse-support control.")
|
| 1104 |
+
print(" v10 does not compute dense audit gradients after warmup unless --audit_every > 0, except oracle_current.")
|
| 1105 |
+
print(" predicted_magnitude updates EMA statistics only from active rows observed by the training backward pass.")
|
| 1106 |
+
print(" kernel_backend=torch uses the portable PyTorch fallback; kernel_backend=metal uses sparse_linear_metal if installed.")
|
| 1107 |
+
print(" Use --benchmark_sync on MPS/CUDA for honest step_ms/tokens_per_s comparisons.")
|
| 1108 |
+
print(" cosine/top20/jacc require --audit_every > 0, otherwise there is no dense reference gradient.")
|
| 1109 |
+
|
| 1110 |
+
|
| 1111 |
+
if __name__ == "__main__":
|
| 1112 |
+
main()
|
experiments/sparse_linear_v11_gather_vs_metal/README.md
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sparse Transformer v10: Metal-backed active-row Linear backward benchmark
|
| 2 |
+
|
| 3 |
+
This bundle is the first empirical optimization test rather than only a correctness prototype.
|
| 4 |
+
It compares dense Transformer training against sparse active-row backward variants and reports
|
| 5 |
+
validation loss plus wall-clock metrics (`step_ms`, `tokens_per_s`).
|
| 6 |
+
|
| 7 |
+
## Files
|
| 8 |
+
|
| 9 |
+
- `sparse_transformer_v10.py` — training + benchmark harness.
|
| 10 |
+
- `sparse_linear.metal` — Metal kernels for active-row `dW/db` and sparse `dX`.
|
| 11 |
+
- `sparse_linear_ops.mm` — PyTorch/MPS extension glue.
|
| 12 |
+
- `setup.py` — builds the extension and compiles the `.metallib`.
|
| 13 |
+
|
| 14 |
+
## Install the Metal extension
|
| 15 |
+
|
| 16 |
+
From this directory on macOS with PyTorch MPS available:
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
python3 -m pip install -e .
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
This uses `xcrun metal` and `xcrun metallib`, so Xcode command-line tools are required.
|
| 23 |
+
|
| 24 |
+
## Correctness/sanity run with PyTorch fallback
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
python3 sparse_transformer_v10.py \
|
| 28 |
+
--device mps \
|
| 29 |
+
--steps 2000 \
|
| 30 |
+
--active_fractions 0.05 0.02 \
|
| 31 |
+
--warmup_steps_list 5 \
|
| 32 |
+
--policies predicted_magnitude random \
|
| 33 |
+
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
|
| 34 |
+
--audit_every 0 \
|
| 35 |
+
--kernel_backend torch \
|
| 36 |
+
--benchmark_sync
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Metal benchmark run
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
python3 sparse_transformer_v10.py \
|
| 43 |
+
--device mps \
|
| 44 |
+
--steps 2000 \
|
| 45 |
+
--active_fractions 0.05 0.02 \
|
| 46 |
+
--warmup_steps_list 5 \
|
| 47 |
+
--policies predicted_magnitude random \
|
| 48 |
+
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
|
| 49 |
+
--audit_every 0 \
|
| 50 |
+
--kernel_backend metal \
|
| 51 |
+
--benchmark_sync
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Empirical pass/fail criteria
|
| 55 |
+
|
| 56 |
+
A genuine optimization would show:
|
| 57 |
+
|
| 58 |
+
1. `predicted_magnitude` validation loss close to the PyTorch fallback v9/v10 result.
|
| 59 |
+
2. `predicted_magnitude` much better than `random` at the same active fraction.
|
| 60 |
+
3. `--kernel_backend metal` has lower `step_ms` or higher `tokens_per_s` than `--kernel_backend torch` and ideally dense baseline.
|
| 61 |
+
|
| 62 |
+
The first Metal kernels are intentionally simple and fp32-only. They are meant to prove or disprove the acceleration path before investing in tiled half-precision kernels.
|
experiments/sparse_linear_v11_gather_vs_metal/input.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
experiments/sparse_linear_v11_gather_vs_metal/setup.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup
|
| 2 |
+
from torch.utils.cpp_extension import CppExtension, BuildExtension
|
| 3 |
+
|
| 4 |
+
setup(
|
| 5 |
+
name='sparse_linear',
|
| 6 |
+
ext_modules=[
|
| 7 |
+
CppExtension(
|
| 8 |
+
name='sparse_linear',
|
| 9 |
+
sources=['sparse_linear_ops.mm'],
|
| 10 |
+
extra_compile_args=['-ObjC++'],
|
| 11 |
+
extra_link_args=['-framework', 'Metal', '-framework', 'Foundation']
|
| 12 |
+
)
|
| 13 |
+
],
|
| 14 |
+
cmdclass={'build_ext': BuildExtension}
|
| 15 |
+
)
|
experiments/sparse_linear_v11_gather_vs_metal/sparse_linear.metal
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_stdlib>
|
| 2 |
+
using namespace metal;
|
| 3 |
+
|
| 4 |
+
struct SparseLinearParams {
|
| 5 |
+
uint32_t N; // flattened batch/token rows
|
| 6 |
+
uint32_t In; // in_features
|
| 7 |
+
uint32_t Out; // out_features
|
| 8 |
+
uint32_t K; // active rows count
|
| 9 |
+
};
|
| 10 |
+
|
| 11 |
+
kernel void sparse_linear_grad_w_float(
|
| 12 |
+
device const float* x [[buffer(0)]], // [N, In]
|
| 13 |
+
device const float* gy [[buffer(1)]], // [N, Out]
|
| 14 |
+
device const int64_t* active [[buffer(2)]], // [K]
|
| 15 |
+
device float* grad_w [[buffer(3)]], // [Out, In], zeroed by caller
|
| 16 |
+
device float* grad_b [[buffer(4)]], // [Out], zeroed by caller
|
| 17 |
+
constant SparseLinearParams& p [[buffer(5)]],
|
| 18 |
+
uint2 tid [[thread_position_in_grid]]) {
|
| 19 |
+
uint k = tid.y;
|
| 20 |
+
uint c = tid.x;
|
| 21 |
+
if (k >= p.K || c >= p.In) return;
|
| 22 |
+
|
| 23 |
+
int64_t row64 = active[k];
|
| 24 |
+
if (row64 < 0 || row64 >= (int64_t)p.Out) return;
|
| 25 |
+
uint row = (uint)row64;
|
| 26 |
+
|
| 27 |
+
float acc = 0.0f;
|
| 28 |
+
for (uint n = 0; n < p.N; ++n) {
|
| 29 |
+
acc += gy[n * p.Out + row] * x[n * p.In + c];
|
| 30 |
+
}
|
| 31 |
+
grad_w[row * p.In + c] = acc;
|
| 32 |
+
|
| 33 |
+
// One thread per active row computes bias.
|
| 34 |
+
if (c == 0) {
|
| 35 |
+
float bacc = 0.0f;
|
| 36 |
+
for (uint n = 0; n < p.N; ++n) {
|
| 37 |
+
bacc += gy[n * p.Out + row];
|
| 38 |
+
}
|
| 39 |
+
grad_b[row] = bacc;
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
kernel void sparse_linear_grad_x_float(
|
| 44 |
+
device const float* gy [[buffer(0)]], // [N, Out]
|
| 45 |
+
device const float* weight [[buffer(1)]], // [Out, In]
|
| 46 |
+
device const int64_t* active [[buffer(2)]], // [K]
|
| 47 |
+
device float* grad_x [[buffer(3)]], // [N, In]
|
| 48 |
+
constant SparseLinearParams& p [[buffer(4)]],
|
| 49 |
+
uint2 tid [[thread_position_in_grid]]) {
|
| 50 |
+
uint c = tid.x;
|
| 51 |
+
uint n = tid.y;
|
| 52 |
+
if (n >= p.N || c >= p.In) return;
|
| 53 |
+
|
| 54 |
+
float acc = 0.0f;
|
| 55 |
+
for (uint k = 0; k < p.K; ++k) {
|
| 56 |
+
int64_t row64 = active[k];
|
| 57 |
+
if (row64 < 0 || row64 >= (int64_t)p.Out) continue;
|
| 58 |
+
uint row = (uint)row64;
|
| 59 |
+
acc += gy[n * p.Out + row] * weight[row * p.In + c];
|
| 60 |
+
}
|
| 61 |
+
grad_x[n * p.In + c] = acc;
|
| 62 |
+
}
|
experiments/sparse_linear_v11_gather_vs_metal/sparse_linear_ops.mm
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include <ATen/ATen.h>
|
| 3 |
+
#include <ATen/mps/MPSStream.h>
|
| 4 |
+
#include <ATen/mps/MPSDevice.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <filesystem>
|
| 7 |
+
#include <dlfcn.h>
|
| 8 |
+
#import <Metal/Metal.h>
|
| 9 |
+
#import <Foundation/Foundation.h>
|
| 10 |
+
#include <mutex>
|
| 11 |
+
|
| 12 |
+
namespace fs = std::filesystem;
|
| 13 |
+
|
| 14 |
+
namespace {
|
| 15 |
+
struct SparseLinearParams {
|
| 16 |
+
uint32_t N;
|
| 17 |
+
uint32_t In;
|
| 18 |
+
uint32_t Out;
|
| 19 |
+
uint32_t K;
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
static id<MTLLibrary> g_lib = nil;
|
| 23 |
+
static id<MTLComputePipelineState> g_pipeline_grad_w = nil;
|
| 24 |
+
static id<MTLComputePipelineState> g_pipeline_grad_x = nil;
|
| 25 |
+
static std::mutex g_mutex;
|
| 26 |
+
|
| 27 |
+
static std::string metallib_path_for_this_module() {
|
| 28 |
+
Dl_info info;
|
| 29 |
+
if (dladdr((void*)&metallib_path_for_this_module, &info) == 0 || info.dli_fname == nullptr) return std::string();
|
| 30 |
+
fs::path so_path(info.dli_fname);
|
| 31 |
+
return (so_path.parent_path() / "sparse_linear_ops.metallib").string();
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
static void ensure_library_locked(id<MTLDevice> device) {
|
| 35 |
+
if (g_lib != nil) return;
|
| 36 |
+
std::string path = metallib_path_for_this_module();
|
| 37 |
+
TORCH_CHECK(!path.empty(), "sparse_linear_ops: failed to locate extension path via dladdr");
|
| 38 |
+
NSString* ns_path = [NSString stringWithUTF8String:path.c_str()];
|
| 39 |
+
NSURL* url = [NSURL fileURLWithPath:ns_path];
|
| 40 |
+
NSError* err = nil;
|
| 41 |
+
g_lib = [device newLibraryWithURL:url error:&err];
|
| 42 |
+
if (g_lib == nil) {
|
| 43 |
+
const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
|
| 44 |
+
TORCH_CHECK(false, "sparse_linear_ops: failed to load metallib at ", path, ": ", msg);
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
static id<MTLComputePipelineState> ensure_pipeline(id<MTLDevice> device, id<MTLComputePipelineState>* pipeline, const char* fn_name) {
|
| 49 |
+
std::lock_guard<std::mutex> lock(g_mutex);
|
| 50 |
+
ensure_library_locked(device);
|
| 51 |
+
if (*pipeline != nil) return *pipeline;
|
| 52 |
+
NSString* ns_fn = [NSString stringWithUTF8String:fn_name];
|
| 53 |
+
id<MTLFunction> fn = [g_lib newFunctionWithName:ns_fn];
|
| 54 |
+
TORCH_CHECK(fn != nil, "sparse_linear_ops: function `", fn_name, "` not found in metallib");
|
| 55 |
+
NSError* err = nil;
|
| 56 |
+
*pipeline = [device newComputePipelineStateWithFunction:fn error:&err];
|
| 57 |
+
if (*pipeline == nil) {
|
| 58 |
+
const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
|
| 59 |
+
TORCH_CHECK(false, "sparse_linear_ops: failed to create pipeline for ", fn_name, ": ", msg);
|
| 60 |
+
}
|
| 61 |
+
return *pipeline;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
static inline id<MTLBuffer> storage_as_mtlbuffer(const at::Tensor& t) {
|
| 65 |
+
void* ctx = t.storage().data_ptr().get_context();
|
| 66 |
+
TORCH_CHECK(ctx != nullptr, "sparse_linear_ops: expected MPS tensor storage with MTLBuffer context");
|
| 67 |
+
return (__bridge id<MTLBuffer>)ctx;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
static inline NSUInteger storage_offset_bytes(const at::Tensor& t) {
|
| 71 |
+
return (NSUInteger)(t.storage_offset() * (int64_t)t.element_size());
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
static void check_mps_float_contig(const at::Tensor& t, const char* name) {
|
| 75 |
+
TORCH_CHECK(t.device().is_mps(), name, " must be on MPS");
|
| 76 |
+
TORCH_CHECK(t.dtype() == at::kFloat, name, " must be float32 for v10 kernel");
|
| 77 |
+
TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
|
| 78 |
+
}
|
| 79 |
+
} // namespace
|
| 80 |
+
|
| 81 |
+
std::vector<at::Tensor> sparse_linear_grad_wb(at::Tensor x2d, at::Tensor gy2d, at::Tensor active_idx, int64_t out_features) {
|
| 82 |
+
check_mps_float_contig(x2d, "x2d");
|
| 83 |
+
check_mps_float_contig(gy2d, "gy2d");
|
| 84 |
+
TORCH_CHECK(active_idx.device().is_mps(), "active_idx must be on MPS");
|
| 85 |
+
TORCH_CHECK(active_idx.dtype() == at::kLong, "active_idx must be int64");
|
| 86 |
+
TORCH_CHECK(active_idx.is_contiguous(), "active_idx must be contiguous");
|
| 87 |
+
TORCH_CHECK(x2d.dim() == 2 && gy2d.dim() == 2 && active_idx.dim() == 1, "bad dims");
|
| 88 |
+
TORCH_CHECK(x2d.size(0) == gy2d.size(0), "x2d and gy2d N mismatch");
|
| 89 |
+
TORCH_CHECK(gy2d.size(1) == out_features, "gy2d width must equal out_features");
|
| 90 |
+
|
| 91 |
+
int64_t N = x2d.size(0);
|
| 92 |
+
int64_t In = x2d.size(1);
|
| 93 |
+
int64_t Out = out_features;
|
| 94 |
+
int64_t K = active_idx.numel();
|
| 95 |
+
auto grad_w = at::zeros({Out, In}, x2d.options());
|
| 96 |
+
auto grad_b = at::zeros({Out}, x2d.options());
|
| 97 |
+
if (K == 0) return {grad_w, grad_b};
|
| 98 |
+
|
| 99 |
+
id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
|
| 100 |
+
id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_w, "sparse_linear_grad_w_float");
|
| 101 |
+
at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
|
| 102 |
+
TORCH_CHECK(stream != nullptr, "failed to get current MPS stream");
|
| 103 |
+
stream->endKernelCoalescing();
|
| 104 |
+
id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
|
| 105 |
+
TORCH_CHECK(encoder != nil, "failed to get MTLComputeCommandEncoder");
|
| 106 |
+
[encoder setComputePipelineState:pipeline];
|
| 107 |
+
|
| 108 |
+
auto set_tensor = [&](const at::Tensor& t, int idx) {
|
| 109 |
+
[encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
|
| 110 |
+
};
|
| 111 |
+
set_tensor(x2d, 0);
|
| 112 |
+
set_tensor(gy2d, 1);
|
| 113 |
+
set_tensor(active_idx, 2);
|
| 114 |
+
set_tensor(grad_w, 3);
|
| 115 |
+
set_tensor(grad_b, 4);
|
| 116 |
+
SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, (uint32_t)K};
|
| 117 |
+
[encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:5];
|
| 118 |
+
|
| 119 |
+
MTLSize tg = MTLSizeMake(16, 16, 1);
|
| 120 |
+
MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((K + 15) / 16), 1);
|
| 121 |
+
[encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
|
| 122 |
+
stream->endKernelCoalescing();
|
| 123 |
+
return {grad_w, grad_b};
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
at::Tensor sparse_linear_grad_x(at::Tensor gy2d, at::Tensor weight, at::Tensor active_idx) {
|
| 127 |
+
check_mps_float_contig(gy2d, "gy2d");
|
| 128 |
+
check_mps_float_contig(weight, "weight");
|
| 129 |
+
TORCH_CHECK(active_idx.device().is_mps(), "active_idx must be on MPS");
|
| 130 |
+
TORCH_CHECK(active_idx.dtype() == at::kLong, "active_idx must be int64");
|
| 131 |
+
TORCH_CHECK(active_idx.is_contiguous(), "active_idx must be contiguous");
|
| 132 |
+
TORCH_CHECK(gy2d.dim() == 2 && weight.dim() == 2 && active_idx.dim() == 1, "bad dims");
|
| 133 |
+
int64_t N = gy2d.size(0);
|
| 134 |
+
int64_t Out = gy2d.size(1);
|
| 135 |
+
int64_t In = weight.size(1);
|
| 136 |
+
int64_t K = active_idx.numel();
|
| 137 |
+
TORCH_CHECK(weight.size(0) == Out, "weight out_features must match gy2d width");
|
| 138 |
+
auto grad_x = at::zeros({N, In}, gy2d.options());
|
| 139 |
+
if (K == 0) return grad_x;
|
| 140 |
+
|
| 141 |
+
id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
|
| 142 |
+
id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_x, "sparse_linear_grad_x_float");
|
| 143 |
+
at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
|
| 144 |
+
TORCH_CHECK(stream != nullptr, "failed to get current MPS stream");
|
| 145 |
+
stream->endKernelCoalescing();
|
| 146 |
+
id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
|
| 147 |
+
TORCH_CHECK(encoder != nil, "failed to get MTLComputeCommandEncoder");
|
| 148 |
+
[encoder setComputePipelineState:pipeline];
|
| 149 |
+
auto set_tensor = [&](const at::Tensor& t, int idx) {
|
| 150 |
+
[encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
|
| 151 |
+
};
|
| 152 |
+
set_tensor(gy2d, 0);
|
| 153 |
+
set_tensor(weight, 1);
|
| 154 |
+
set_tensor(active_idx, 2);
|
| 155 |
+
set_tensor(grad_x, 3);
|
| 156 |
+
SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, (uint32_t)K};
|
| 157 |
+
[encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:4];
|
| 158 |
+
MTLSize tg = MTLSizeMake(16, 16, 1);
|
| 159 |
+
MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((N + 15) / 16), 1);
|
| 160 |
+
[encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
|
| 161 |
+
stream->endKernelCoalescing();
|
| 162 |
+
return grad_x;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 166 |
+
m.def("sparse_linear_grad_wb", &sparse_linear_grad_wb, "Sparse active-row Linear dW/db (Metal/MPS, fp32)");
|
| 167 |
+
m.def("sparse_linear_grad_x", &sparse_linear_grad_x, "Sparse active-row Linear dX (Metal/MPS, fp32)");
|
| 168 |
+
}
|
experiments/sparse_linear_v11_gather_vs_metal/sparse_transformer_v11.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v12: Hardware-Sympathetic Chunked Sparsity.
|
| 3 |
+
|
| 4 |
+
This version groups rows into hardware-friendly "Chunks" (e.g., 64 rows per chunk).
|
| 5 |
+
By selecting entire chunks, PyTorch can use zero-copy Strided Views. This completely
|
| 6 |
+
bypasses the slow index_select/gather memory copying and feeds data directly into
|
| 7 |
+
the AMX / Tensor Cores at native dense speeds.
|
| 8 |
+
|
| 9 |
+
Run:
    python3 sparse_transformer_v11.py --device mps --benchmark_sync
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import math
|
| 15 |
+
import random
|
| 16 |
+
import time
|
| 17 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
torch.set_num_threads(1)
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
def sync_device(device: str) -> None:
    """Block until all queued kernels on *device* have finished.

    No-op for devices without an async queue (e.g. CPU) or when the
    requested backend is unavailable in this torch build.
    """
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps":
        if hasattr(torch, "mps"):
            torch.mps.synchronize()
|
| 31 |
+
Policy = Literal["predicted_magnitude", "oracle_current", "random"]
|
| 32 |
+
BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
|
| 33 |
+
|
| 34 |
+
def set_seed(seed: int) -> None:
    """Seed both the stdlib and torch RNGs for reproducible runs."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
| 37 |
+
|
| 38 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU-resident torch.Generator seeded with *seed*.

    Generator.manual_seed returns the generator itself, so the construction
    and seeding can be chained into a single expression.
    """
    return torch.Generator(device="cpu").manual_seed(seed)
|
| 42 |
+
|
| 43 |
+
# -----------------------------
|
| 44 |
+
# Data
|
| 45 |
+
# -----------------------------
|
| 46 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Generate a deterministic toy corpus of '.'-terminated word sentences,
    one per line, drawn from a fixed 10-word vocabulary."""
    rng = random.Random(seed)
    vocab = ["ada", "turing", "grace", "lovelace", "gradients", "tokens",
             "circuits", "features", "boldly", "strangely"]
    sentences = []
    for _ in range(n_sentences):
        # Draw the length first, then the words, matching the RNG call order.
        n_words = rng.randint(4, 10)
        sentences.append(" ".join(rng.choices(vocab, k=n_words)) + ".")
    return "\n".join(sentences)
|
| 50 |
+
|
| 51 |
+
class CharCorpus:
    """Character-level corpus with a 90/10 train/val split.

    Builds char<->id vocabularies from *text* and serves random (x, y)
    next-character batches moved to *device*.
    """

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {c: i for i, c in enumerate(vocab)}
        self.itos = {i: c for i, c in enumerate(vocab)}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device
        encoded = torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample *batch_size* random windows; y is x shifted one char right."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        xs = torch.stack([source[s : s + self.block_size] for s in starts])
        ys = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return xs.to(self.device), ys.to(self.device)
|
| 69 |
+
|
| 70 |
+
# -----------------------------
|
| 71 |
+
# Chunked Sparse Autograd
|
| 72 |
+
# -----------------------------
|
| 73 |
+
|
| 74 |
+
class ChunkedMaskedLinear(torch.autograd.Function):
    """Linear op whose backward pass only materializes gradients for a
    selected set of contiguous output-row "chunks".

    Forward is a plain dense ``F.linear``; the sparsity applies only in
    backward, where dW/db (and optionally dX) are computed chunk by chunk
    through zero-copy strided views instead of gather/index_select.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], active_chunks: torch.Tensor, chunk_size: int, sparse_dx: bool) -> torch.Tensor:
        # active_chunks holds local chunk indices: chunk c covers output
        # rows [c*chunk_size, (c+1)*chunk_size).
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        chunk_size = ctx.chunk_size

        # Flatten leading dims: (B, T, C) -> (B*T, C) so each chunk update
        # is a single 2-D matmul.
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])

        # Initialize full zero gradients; rows of inactive chunks stay zero.
        grad_w = torch.zeros_like(weight)
        grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if ctx.has_bias else None

        if ctx.sparse_dx:
            # dX accumulated only from the active chunks in the loop below.
            grad_x_flat = torch.zeros_like(x_flat)
        else:
            # Full (dense) dX: every output column contributes.
            grad_x_flat = gy_flat @ weight

        # THE MAGIC: Zero-copy Strided Views
        for c_idx in active_chunks.tolist():
            start = c_idx * chunk_size
            end = start + chunk_size

            # These are views! No memory allocation, no gathering.
            gy_slice = gy_flat[:, start:end]
            w_slice = weight[start:end, :]

            # Triggers dense hardware matmul (AMX / Tensor Cores) directly
            grad_w[start:end, :] = gy_slice.t() @ x_flat

            if ctx.has_bias:
                grad_b[start:end] = gy_slice.sum(dim=0)

            if ctx.sparse_dx:
                grad_x_flat += gy_slice @ w_slice

        # One grad per forward argument; the non-tensor inputs
        # (active_chunks, chunk_size, sparse_dx) receive None.
        return grad_x_flat.reshape(x.shape), grad_w, grad_b, None, None, None
|
| 119 |
+
|
| 120 |
+
class SparseLinear(nn.Linear):
    """nn.Linear that can route its backward pass through ChunkedMaskedLinear.

    With ``sparse_enabled`` off or no ``active_chunks`` assigned, it behaves
    exactly like a plain nn.Linear.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        # Sparse mode stays off until a masker assigns active chunks.
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_chunks: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        use_sparse = self.sparse_enabled and self.active_chunks is not None
        if use_sparse:
            return ChunkedMaskedLinear.apply(
                x, self.weight, self.bias, self.active_chunks,
                getattr(self, "chunk_size", 64), self.sparse_dx,
            )
        return F.linear(x, self.weight, self.bias)
|
| 131 |
+
|
| 132 |
+
# -----------------------------
|
| 133 |
+
# Mini GPT
|
| 134 |
+
# -----------------------------
|
| 135 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with SparseLinear projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused QKV projection; chunk sparsity (if enabled) applies per output chunk.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask, broadcastable over (batch, head).
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Scaled dot-product scores with causal masking.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        # Merge heads back to (B, T, C) before the output projection.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
|
| 160 |
+
|
| 161 |
+
class FeedForward(nn.Module):
    """Standard GPT MLP: expand 4x, GELU, project back, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
|
| 170 |
+
|
| 171 |
+
class Block(nn.Module):
    """Pre-LN transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
| 183 |
+
|
| 184 |
+
class MiniGPT(nn.Module):
    """Minimal GPT: token + learned positional embeddings, n_layer pre-LN
    blocks, a final LayerNorm, and a dense LM head."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)

        # Standard nn.Linear for the LM head so it isn't restricted by chunk sizes
        # and correctly calculates cross-entropy probabilities.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when *targets* is omitted."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 208 |
+
|
| 209 |
+
def get_sparse_linears(model):
    """Collect every SparseLinear module contained in *model*."""
    return list(filter(lambda m: isinstance(m, SparseLinear), model.modules()))
|
| 211 |
+
|
| 212 |
+
# -----------------------------
|
| 213 |
+
# Chunk Masker
|
| 214 |
+
# -----------------------------
|
| 215 |
+
class ChunkMasker:
    """Tracks per-chunk gradient mass and selects which output chunks of
    each SparseLinear are active on the next step.

    Chunk ids are laid out globally across all sparse layers: each module
    owns a contiguous slice of ids so a global boolean mask can be mapped
    back to local per-module chunk indices.
    """

    def __init__(self, model: nn.Module, policy: Policy, active_fraction: float, chunk_size: int, device: str):
        self.policy = policy
        self.active_fraction = active_fraction
        self.chunk_size = chunk_size
        self.device = device

        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids = {}
        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, f"out_features {m.out_features} not divisible by chunk size {chunk_size}"
            n_chunks = m.out_features // chunk_size
            self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device)
            offset += n_chunks

        self.n_chunks = offset
        # EMA of per-chunk gradient L2 norm; drives the predicted_magnitude policy.
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)

    def choose_active(self, step: int, warmup_steps: int):
        """Select the active chunk set for this step and push local ids to modules.

        Raises:
            ValueError: if ``self.policy`` is not a supported selection policy.
        """
        if step < warmup_steps:
            # Warmup: everything active so the EMA observes real gradients.
            self.active_chunks.fill_(True)
            for m, ids in self.module_to_chunk_ids.items():
                m.active_chunks = torch.arange(len(ids), device=self.device)
            return

        k = max(1, int(self.active_fraction * self.n_chunks))
        self.active_chunks.fill_(False)

        if self.policy == "random":
            self.active_chunks[torch.randperm(self.n_chunks, device=self.device)[:k]] = True
        elif self.policy == "predicted_magnitude":
            # Add tiny noise for tie-breaking
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            self.active_chunks[torch.topk(scores, k=k).indices] = True
        else:
            # BUG FIX: previously any other policy (e.g. the advertised
            # "oracle_current") silently selected ZERO chunks, freezing all
            # sparse layers without any error. Fail loudly instead.
            raise ValueError(f"unsupported policy: {self.policy!r}")

        for m, ids in self.module_to_chunk_ids.items():
            global_active = self.active_chunks[ids]
            local_ids = torch.arange(len(ids), device=self.device)
            m.active_chunks = local_ids[global_active]

    @torch.no_grad()
    def update_predictor(self, mass_beta=0.95):
        """EMA-update the predicted gradient mass from chunks observed this step."""
        # Calculate true L2 norm of the gradient for each CHUNK
        current_mass = torch.zeros_like(self.predicted_mass)
        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None:
                continue
            # Reshape [Out, In] -> [n_chunks, chunk_size, In]; sum of squares
            # over chunk rows and In gives the squared chunk norm.
            w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)
            current_mass[ids] = torch.sqrt(w_sq + 1e-30)

        # Only update observed (active) chunks; others keep their stale estimate.
        observed = self.active_chunks
        self.predicted_mass[observed] = mass_beta * self.predicted_mass[observed] + (1.0 - mass_beta) * current_mass[observed]
|
| 272 |
+
|
| 273 |
+
# -----------------------------
|
| 274 |
+
# Chunked Adam
|
| 275 |
+
# -----------------------------
|
| 276 |
+
class ChunkedAdam:
    """Hand-rolled Adam (fixed betas 0.9/0.999, eps 1e-8; no bias
    correction — no step counter is kept) that applies moment and parameter
    updates only to the active chunks of SparseLinear parameters, and dense
    updates everywhere else."""

    def __init__(self, model, lr=3e-4, chunk_size=64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        # Per-parameter Adam state ("m"/"v"), created lazily on first step().
        self.state = {}

        # Keep track of which parameters belong to sparse modules
        self.param_to_sparse_module = {}
        for m in get_sparse_linears(model):
            if m.weight is not None: self.param_to_sparse_module[m.weight] = m
            if m.bias is not None: self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        # Drop grads entirely (cheaper than zeroing in place).
        for p in self.model.parameters(): p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None: continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}

            exp_avg, exp_avg_sq = self.state[p]["m"], self.state[p]["v"]

            sparse_module = self.param_to_sparse_module.get(p)
            active_chunks = getattr(sparse_module, 'active_chunks', None) if sparse_module else None

            if active_chunks is None:
                # Dense update for embeddings, layernorms, LM head, or baseline
                exp_avg.mul_(0.9).add_(p.grad, alpha=0.1)
                exp_avg_sq.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                update = exp_avg / (torch.sqrt(exp_avg_sq) + 1e-8)
                p.sub_(update, alpha=self.lr)
            else:
                # Sparse update ONLY on active chunks (indices are local per module)
                for local_c in active_chunks.tolist():
                    start = local_c * self.chunk_size
                    end = (local_c + 1) * self.chunk_size

                    # Slices are views: the in-place ops below write straight
                    # through to the parameter and optimizer state.
                    p_chunk = p[start:end]
                    g_chunk = p.grad[start:end]
                    m_chunk = exp_avg[start:end]
                    v_chunk = exp_avg_sq[start:end]

                    m_chunk.mul_(0.9).add_(g_chunk, alpha=0.1)
                    v_chunk.mul_(0.999).addcmul_(g_chunk, g_chunk, value=0.001)

                    update = m_chunk / (torch.sqrt(v_chunk) + 1e-8)
                    p_chunk.sub_(update, alpha=self.lr)
|
| 326 |
+
|
| 327 |
+
# -----------------------------
|
| 328 |
+
# Training
|
| 329 |
+
# -----------------------------
|
| 330 |
+
def main():
    """Benchmark dense vs chunk-sparse backward modes on a synthetic corpus."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=16)
    parser.add_argument("--n_embd", type=int, default=1024)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.05)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    corpus = CharCorpus(make_synthetic_corpus(), args.block_size, args.device)

    # (selection policy, backward mode) pairs compared in a single run.
    modes = [
        ("dense_baseline", "dense_baseline"),
        ("predicted_magnitude", "sparse_dW_full_dX"),
        ("predicted_magnitude", "sparse_dW_sparse_dX")
    ]

    print(f"\nModel: {args.n_layer} layers, {args.n_embd} d_model, {args.chunk_size} chunk_size")
    print(f"Batch: {args.batch_size}, Block: {args.block_size}. Active Fraction: {args.active_fraction}\n")
    print(f"{'Run':>20s} | {'Time (s)':>10s} | {'Step (ms)':>10s} | {'Val Loss':>8s}")
    print("-" * 55)

    for policy, bwd_mode in modes:
        # Identical init per run for a fair comparison.
        set_seed(42)
        model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.0).to(args.device)

        for m in get_sparse_linears(model):
            m.chunk_size = args.chunk_size

        masker = ChunkMasker(model, policy, args.active_fraction, args.chunk_size, args.device) if policy != "dense_baseline" else None
        opt = ChunkedAdam(model, chunk_size=args.chunk_size)

        if args.benchmark_sync: sync_device(args.device)
        t0 = time.perf_counter()

        for step in range(args.steps):
            # Per-step seeded CPU generator => identical batch order across runs.
            x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))

            if masker:
                masker.choose_active(step, warmup_steps=5)
                for m in get_sparse_linears(model):
                    m.sparse_enabled = True
                    m.sparse_dx = (bwd_mode == "sparse_dW_sparse_dX")
            else:
                for m in get_sparse_linears(model):
                    m.sparse_enabled = False
                    m.active_chunks = None

            opt.zero_grad()
            _, loss = model(x, y)
            loss.backward()

            if masker:
                masker.update_predictor()

            opt.step()

        if args.benchmark_sync: sync_device(args.device)
        t_elapsed = time.perf_counter() - t0

        # Eval loss
        model.eval()
        with torch.no_grad():
            x, y = corpus.get_batch("val", args.batch_size, generator=make_cpu_generator(999))
            _, val_loss = model(x, y)

        # Format the mode strictly for the printout width
        bwd_str = bwd_mode if bwd_mode == "dense_baseline" else ("sparse_full_dX" if "full_dX" in bwd_mode else "sparse_sparse_dX")
        print(f"{bwd_str:>20s} | {t_elapsed:10.2f} | {1000*t_elapsed/args.steps:10.2f} | {val_loss.item():8.4f}")

if __name__ == "__main__":
    main()
|
experiments/sparse_linear_v11_gather_vs_metal/sparse_transformer_v13.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v14: The Final Architecture.
|
| 3 |
+
|
| 4 |
+
Combines the Hardware-Sympathetic Chunked Sparse backward pass with Cosine Annealing,
|
| 5 |
+
but restores the Chunked Optimizer to prevent Dense memory-bandwidth bottlenecks.
|
| 6 |
+
Benchmarks are isolated to the steady-state phase (after annealing) for accurate timings.
|
| 7 |
+
|
| 8 |
+
Run:
    python3 sparse_transformer_v13.py --device mps --benchmark_sync --n_embd 1024
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import math
|
| 14 |
+
import random
|
| 15 |
+
import time
|
| 16 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
torch.set_num_threads(1)
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
def sync_device(device: str) -> None:
    """Block until all queued kernels on *device* finish; no-op otherwise."""
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
|
| 29 |
+
|
| 30 |
+
Policy = Literal["predicted_magnitude", "oracle_current", "random"]
|
| 31 |
+
BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
|
| 32 |
+
|
| 33 |
+
def set_seed(seed: int) -> None:
    """Seed both the stdlib and torch RNGs for reproducible runs."""
    random.seed(seed)
    torch.manual_seed(seed)
|
| 36 |
+
|
| 37 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU-resident torch.Generator seeded with *seed*."""
    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed)
    return gen
|
| 41 |
+
|
| 42 |
+
# -----------------------------
|
| 43 |
+
# Data
|
| 44 |
+
# -----------------------------
|
| 45 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Deterministic toy corpus: newline-separated, '.'-terminated sentences
    of 4-10 words drawn from a fixed 10-word vocabulary."""
    rng = random.Random(seed)
    words = ["ada", "turing", "grace", "lovelace", "gradients", "tokens", "circuits", "features", "boldly", "strangely"]
    return "\n".join(" ".join(rng.choices(words, k=rng.randint(4, 10))) + "." for _ in range(n_sentences))
|
| 49 |
+
|
| 50 |
+
class CharCorpus:
|
| 51 |
+
def __init__(self, text: str, block_size: int, device: str):
|
| 52 |
+
chars = sorted(set(text))
|
| 53 |
+
self.stoi = {ch: i for i, ch in enumerate(chars)}
|
| 54 |
+
self.itos = {i: ch for ch, i in self.stoi.items()}
|
| 55 |
+
self.vocab_size = len(chars)
|
| 56 |
+
self.block_size = block_size
|
| 57 |
+
self.device = device
|
| 58 |
+
data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
|
| 59 |
+
self.train_data = data[:int(0.9 * len(data))]
|
| 60 |
+
self.val_data = data[int(0.9 * len(data)):]
|
| 61 |
+
|
| 62 |
+
def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 63 |
+
data = self.train_data if split == "train" else self.val_data
|
| 64 |
+
ix = torch.randint(len(data) - self.block_size - 1, (batch_size,), generator=generator)
|
| 65 |
+
x = torch.stack([data[i : i + self.block_size] for i in ix])
|
| 66 |
+
y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
|
| 67 |
+
return x.to(self.device), y.to(self.device)
|
| 68 |
+
|
| 69 |
+
# -----------------------------
|
| 70 |
+
# Chunked Sparse Autograd
|
| 71 |
+
# -----------------------------
|
| 72 |
+
|
| 73 |
+
class ChunkedMaskedLinear(torch.autograd.Function):
|
| 74 |
+
@staticmethod
|
| 75 |
+
def forward(ctx, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], active_chunks: torch.Tensor, chunk_size: int, sparse_dx: bool) -> torch.Tensor:
|
| 76 |
+
ctx.save_for_backward(x, weight, active_chunks)
|
| 77 |
+
ctx.has_bias = bias is not None
|
| 78 |
+
ctx.sparse_dx = sparse_dx
|
| 79 |
+
ctx.chunk_size = chunk_size
|
| 80 |
+
return F.linear(x, weight, bias)
|
| 81 |
+
|
| 82 |
+
@staticmethod
|
| 83 |
+
def backward(ctx, grad_y: torch.Tensor):
|
| 84 |
+
x, weight, active_chunks = ctx.saved_tensors
|
| 85 |
+
chunk_size = ctx.chunk_size
|
| 86 |
+
|
| 87 |
+
x_flat = x.reshape(-1, x.shape[-1])
|
| 88 |
+
gy_flat = grad_y.reshape(-1, grad_y.shape[-1])
|
| 89 |
+
|
| 90 |
+
grad_w = torch.zeros_like(weight)
|
| 91 |
+
grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if ctx.has_bias else None
|
| 92 |
+
|
| 93 |
+
if ctx.sparse_dx:
|
| 94 |
+
grad_x_flat = torch.zeros_like(x_flat)
|
| 95 |
+
else:
|
| 96 |
+
grad_x_flat = gy_flat @ weight
|
| 97 |
+
|
| 98 |
+
# Zero-copy Strided Views
|
| 99 |
+
for c_idx in active_chunks.tolist():
|
| 100 |
+
start = c_idx * chunk_size
|
| 101 |
+
end = start + chunk_size
|
| 102 |
+
|
| 103 |
+
gy_slice = gy_flat[:, start:end]
|
| 104 |
+
w_slice = weight[start:end, :]
|
| 105 |
+
|
| 106 |
+
grad_w[start:end, :] = gy_slice.t() @ x_flat
|
| 107 |
+
|
| 108 |
+
if ctx.has_bias:
|
| 109 |
+
grad_b[start:end] = gy_slice.sum(dim=0)
|
| 110 |
+
|
| 111 |
+
if ctx.sparse_dx:
|
| 112 |
+
grad_x_flat += gy_slice @ w_slice
|
| 113 |
+
|
| 114 |
+
return grad_x_flat.reshape(x.shape), grad_w, grad_b, None, None, None
|
| 115 |
+
|
| 116 |
+
class SparseLinear(nn.Linear):
|
| 117 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = True):
|
| 118 |
+
super().__init__(in_features, out_features, bias=bias)
|
| 119 |
+
self.sparse_enabled = False
|
| 120 |
+
self.sparse_dx = False
|
| 121 |
+
self.active_chunks: Optional[torch.Tensor] = None
|
| 122 |
+
|
| 123 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 124 |
+
if not self.sparse_enabled or self.active_chunks is None:
|
| 125 |
+
return F.linear(x, self.weight, self.bias)
|
| 126 |
+
return ChunkedMaskedLinear.apply(x, self.weight, self.bias, self.active_chunks, getattr(self, 'chunk_size', 64), self.sparse_dx)
|
| 127 |
+
|
| 128 |
+
# -----------------------------
|
| 129 |
+
# Mini GPT
|
| 130 |
+
# -----------------------------
|
| 131 |
+
class CausalSelfAttention(nn.Module):
|
| 132 |
+
def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
|
| 133 |
+
super().__init__()
|
| 134 |
+
assert n_embd % n_head == 0
|
| 135 |
+
self.n_head = n_head
|
| 136 |
+
self.head_dim = n_embd // n_head
|
| 137 |
+
self.c_attn = SparseLinear(n_embd, 3 * n_embd)
|
| 138 |
+
self.c_proj = SparseLinear(n_embd, n_embd)
|
| 139 |
+
self.dropout = nn.Dropout(dropout)
|
| 140 |
+
self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))
|
| 141 |
+
|
| 142 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 143 |
+
B, T, C = x.shape
|
| 144 |
+
qkv = self.c_attn(x)
|
| 145 |
+
q, k, v = qkv.split(C, dim=2)
|
| 146 |
+
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
|
| 147 |
+
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
|
| 148 |
+
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
|
| 149 |
+
att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 150 |
+
att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
|
| 151 |
+
att = F.softmax(att, dim=-1)
|
| 152 |
+
att = self.dropout(att)
|
| 153 |
+
y = att @ v
|
| 154 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 155 |
+
return self.c_proj(y)
|
| 156 |
+
|
| 157 |
+
class FeedForward(nn.Module):
    """Standard GPT MLP: expand 4x, GELU, project back, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))
|
| 166 |
+
|
| 167 |
+
class Block(nn.Module):
    """Pre-LN transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
|
| 179 |
+
|
| 180 |
+
class MiniGPT(nn.Module):
    """Small GPT: token + learned positional embeddings, n_layer pre-norm blocks, dense LM head."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        # The LM head is a plain nn.Linear (not SparseLinear): the cross-entropy
        # loss below consumes the full vocabulary distribution.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when `targets` is not supplied."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 201 |
+
|
| 202 |
+
def get_sparse_linears(model):
    """Collect every SparseLinear submodule of `model` in module-traversal order."""
    found = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
|
| 204 |
+
|
| 205 |
+
# -----------------------------
|
| 206 |
+
# Chunk Masker with Annealing
|
| 207 |
+
# -----------------------------
|
| 208 |
+
class ChunkMasker:
    """Chooses which output chunks of each SparseLinear receive gradients.

    Chunks are numbered globally across all SparseLinear modules; the active
    fraction is cosine-annealed from 1.0 (dense) down to `target_fraction`
    after a dense warmup phase.
    """

    def __init__(self, model: nn.Module, policy: Policy, target_fraction: float, chunk_size: int, device: str):
        self.policy = policy
        self.target_fraction = target_fraction
        self.chunk_size = chunk_size
        self.device = device

        self.linears = get_sparse_linears(model)
        # Map each module to its contiguous slice of the global chunk-id space.
        self.module_to_chunk_ids = {}
        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, f"out_features {m.out_features} not divisible by chunk size {chunk_size}"
            n_chunks = m.out_features // chunk_size
            self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device)
            offset += n_chunks

        self.n_chunks = offset
        # EMA of per-chunk gradient magnitude, consumed by "predicted_magnitude".
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # Global boolean mask of the chunks selected for the current step.
        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)

    def choose_active(self, step: int, warmup_steps: int, anneal_steps: int):
        # Cosine Annealing Logic: fully dense during warmup, then decay the
        # active fraction along a half-cosine down to target_fraction.
        if step < warmup_steps:
            current_fraction = 1.0
        elif step < warmup_steps + anneal_steps:
            progress = (step - warmup_steps) / anneal_steps
            cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
            current_fraction = self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
        else:
            current_fraction = self.target_fraction

        if current_fraction >= 0.999:
            # Effectively dense: mark every chunk active on every module.
            self.active_chunks.fill_(True)
            for m, ids in self.module_to_chunk_ids.items():
                m.active_chunks = torch.arange(len(ids), device=self.device)
            return

        k = max(1, int(current_fraction * self.n_chunks))
        self.active_chunks.fill_(False)

        if self.policy == "random":
            self.active_chunks[torch.randperm(self.n_chunks, device=self.device)[:k]] = True
        elif self.policy == "predicted_magnitude":
            # Tiny random jitter breaks ties among equal (e.g. all-zero) scores.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            self.active_chunks[torch.topk(scores, k=k).indices] = True
        # NOTE(review): the "oracle_current" policy from the Policy alias is not
        # handled here; with it, no chunk would ever be selected — confirm intended.

        # Translate the global mask into per-module local chunk indices.
        for m, ids in self.module_to_chunk_ids.items():
            global_active = self.active_chunks[ids]
            local_ids = torch.arange(len(ids), device=self.device)
            m.active_chunks = local_ids[global_active]

    @torch.no_grad()
    def update_predictor(self, mass_beta=0.95):
        """EMA-update the predicted per-chunk gradient mass from current grads."""
        current_mass = torch.zeros_like(self.predicted_mass)
        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None: continue
            # Squared-norm of each chunk's weight-gradient rows.
            w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)
            current_mass[ids] = torch.sqrt(w_sq + 1e-30)

        # Only chunks that actually received gradients this step are observed;
        # inactive chunks keep their stale EMA estimate.
        observed = self.active_chunks
        self.predicted_mass[observed] = mass_beta * self.predicted_mass[observed] + (1.0 - mass_beta) * current_mass[observed]
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# -----------------------------
|
| 274 |
+
# Chunked Adam (Restored)
|
| 275 |
+
# -----------------------------
|
| 276 |
+
class ChunkedAdam:
    """Adam (without bias correction) that updates only active chunks of sparse params.

    Non-sparse parameters get a full dense Adam update. Hyperparameters are
    fixed here: beta1=0.9, beta2=0.999, eps=1e-8.
    """

    def __init__(self, model, lr=3e-4, chunk_size=64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        # Per-parameter Adam moments, created lazily on first step().
        self.state = {}

        # Keep track of which parameters belong to sparse modules
        self.param_to_sparse_module = {}
        for m in get_sparse_linears(model):
            if m.weight is not None: self.param_to_sparse_module[m.weight] = m
            if m.bias is not None: self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        # Drop grads to None rather than zeroing in place.
        for p in self.model.parameters(): p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None: continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}

            exp_avg, exp_avg_sq = self.state[p]["m"], self.state[p]["v"]

            sparse_module = self.param_to_sparse_module.get(p)
            active_chunks = getattr(sparse_module, 'active_chunks', None) if sparse_module else None

            if active_chunks is None:
                # Dense update for embeddings, layernorms, LM head, or baseline
                exp_avg.mul_(0.9).add_(p.grad, alpha=0.1)
                exp_avg_sq.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                update = exp_avg / (torch.sqrt(exp_avg_sq) + 1e-8)
                p.sub_(update, alpha=self.lr)
            else:
                # Sparse update ONLY on the active chunks
                for local_c in active_chunks.tolist():
                    start = local_c * self.chunk_size
                    end = (local_c + 1) * self.chunk_size

                    # Slices are views of the parent tensors, so the in-place
                    # ops below write straight into p and the stored moments.
                    p_chunk = p[start:end]
                    g_chunk = p.grad[start:end]
                    m_chunk = exp_avg[start:end]
                    v_chunk = exp_avg_sq[start:end]

                    m_chunk.mul_(0.9).add_(g_chunk, alpha=0.1)
                    v_chunk.mul_(0.999).addcmul_(g_chunk, g_chunk, value=0.001)

                    update = m_chunk / (torch.sqrt(v_chunk) + 1e-8)
                    p_chunk.sub_(update, alpha=self.lr)
|
| 326 |
+
|
| 327 |
+
# -----------------------------
|
| 328 |
+
# Training
|
| 329 |
+
# -----------------------------
|
| 330 |
+
def main():
    """Benchmark dense vs. chunk-sparse backward modes on the synthetic corpus.

    Runs three configurations (dense baseline, sparse dW with full dX, sparse
    dW with sparse dX), timing only the post-anneal steady-state steps, and
    reports elapsed time, ms/step, and final validation loss for each.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=16)
    parser.add_argument("--n_embd", type=int, default=1024)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.05)
    parser.add_argument("--warmup_steps", type=int, default=10)
    parser.add_argument("--anneal_steps", type=int, default=150)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    corpus = CharCorpus(make_synthetic_corpus(), args.block_size, args.device)

    # (chunk-selection policy, backward mode) pairs to benchmark.
    modes =[
        ("dense_baseline", "dense_baseline"),
        ("predicted_magnitude", "sparse_dW_full_dX"),
        ("predicted_magnitude", "sparse_dW_sparse_dX")
    ]

    print(f"\nModel: {args.n_layer} layers, {args.n_embd} d_model, {args.chunk_size} chunk_size")
    print(f"Batch: {args.batch_size}, Block: {args.block_size}. Target Active Fraction: {args.active_fraction}")
    print(f"Annealing: {args.warmup_steps} warmup steps, {args.anneal_steps} anneal steps.\n")
    print(f"{'Run':>20s} | {'Time (s)':>10s} | {'Step (ms)':>10s} | {'Val Loss':>8s}")
    print("-" * 55)

    for policy, bwd_mode in modes:
        # Fresh identically-seeded model per run so losses are comparable.
        set_seed(42)
        model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.0).to(args.device)

        for m in get_sparse_linears(model):
            m.chunk_size = args.chunk_size

        masker = ChunkMasker(model, policy, args.active_fraction, args.chunk_size, args.device) if policy != "dense_baseline" else None

        # Restoring the sparse optimizer!
        opt = ChunkedAdam(model, chunk_size=args.chunk_size)

        if args.benchmark_sync: sync_device(args.device)

        # We will only measure the time AFTER annealing finishes to get the true steady-state sparse speed.
        t0 = time.perf_counter()
        measured_steps = args.steps

        for step in range(args.steps):
            # Reset the timer once we hit the target sparsity
            if step == args.warmup_steps + args.anneal_steps:
                if args.benchmark_sync: sync_device(args.device)
                t0 = time.perf_counter()
                measured_steps = args.steps - step

            # Per-step CPU generator keeps batch sampling identical across runs.
            x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))

            if masker:
                masker.choose_active(step, warmup_steps=args.warmup_steps, anneal_steps=args.anneal_steps)
                for m in get_sparse_linears(model):
                    m.sparse_enabled = True
                    m.sparse_dx = (bwd_mode == "sparse_dW_sparse_dX")
            else:
                for m in get_sparse_linears(model):
                    m.sparse_enabled = False
                    m.active_chunks = None

            opt.zero_grad()
            _, loss = model(x, y)
            loss.backward()

            # Predictor must see this step's grads before the optimizer consumes them.
            if masker:
                masker.update_predictor()

            opt.step()

        if args.benchmark_sync: sync_device(args.device)
        t_elapsed = time.perf_counter() - t0

        # Eval loss
        model.eval()
        with torch.no_grad():
            x, y = corpus.get_batch("val", args.batch_size, generator=make_cpu_generator(999))
            _, val_loss = model(x, y)

        bwd_str = bwd_mode if bwd_mode == "dense_baseline" else ("sparse_full_dX" if "full_dX" in bwd_mode else "sparse_sparse_dX")
        print(f"{bwd_str:>20s} | {t_elapsed:10.2f} | {1000*t_elapsed/max(1, measured_steps):10.2f} | {val_loss.item():8.4f}")
|
| 417 |
+
|
| 418 |
+
if __name__ == "__main__":
|
| 419 |
+
main()
|
experiments/sparse_linear_v11_gather_vs_metal/tiny.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer: Real-World Benchmark on Tiny Shakespeare using GPT-2 BPE.
|
| 3 |
+
|
| 4 |
+
This script scales the architecture to a 6-layer, 512-dim GPT and trains on
|
| 5 |
+
real natural language. It applies our Hardware-Sympathetic Chunked Sparse
|
| 6 |
+
backward pass, Cosine Annealing, and Chunked Adam optimizer.
|
| 7 |
+
|
| 8 |
+
Run:
|
| 9 |
+
    python3 tiny.py --device mps --benchmark_sync
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import math
|
| 14 |
+
import os
|
| 15 |
+
import random
|
| 16 |
+
import time
|
| 17 |
+
import urllib.request
|
| 18 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
import tiktoken
|
| 26 |
+
except ImportError:
|
| 27 |
+
raise ImportError("Please install tiktoken: pip install tiktoken")
|
| 28 |
+
|
| 29 |
+
torch.set_num_threads(1)
|
| 30 |
+
|
| 31 |
+
def sync_device(device: str) -> None:
    """Block until all queued kernels on `device` finish; no-op for other devices."""
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
|
| 36 |
+
|
| 37 |
+
Policy = Literal["predicted_magnitude", "oracle_current", "random"]
|
| 38 |
+
BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
|
| 39 |
+
|
| 40 |
+
def set_seed(seed: int) -> None:
    """Seed both the stdlib RNG and torch's global RNG for reproducibility."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
|
| 43 |
+
|
| 44 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Build a fresh CPU torch.Generator seeded deterministically with `seed`."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator
|
| 48 |
+
|
| 49 |
+
# -----------------------------
|
| 50 |
+
# Real-World Data Pipeline
|
| 51 |
+
# -----------------------------
|
| 52 |
+
class ShakespeareCorpus:
    """Tiny Shakespeare dataset, GPT-2 BPE tokenized, with a 90/10 train/val split.

    Downloads `input.txt` into the working directory on first use.
    """

    def __init__(self, block_size: int, device: str):
        self.block_size = block_size
        self.device = device

        # 1. Download Tiny Shakespeare if not exists
        data_path = "input.txt"
        if not os.path.exists(data_path):
            print("Downloading Tiny Shakespeare...")
            url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
            urllib.request.urlretrieve(url, data_path)

        # 2. Tokenize using GPT-2 BPE
        print("Tokenizing data...")
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.read()

        enc = tiktoken.get_encoding("gpt2")
        tokens = enc.encode(text)
        self.vocab_size = enc.n_vocab

        # 3. Split 90/10 Train/Val
        data = torch.tensor(tokens, dtype=torch.long)
        split_idx = int(0.9 * len(data))
        self.train_data = data[:split_idx]
        self.val_data = data[split_idx:]

        print(f"Dataset loaded. Vocab size: {self.vocab_size:,}. Train tokens: {len(self.train_data):,}")

    def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample `batch_size` random windows; returns (inputs, next-token targets)."""
        data = self.train_data if split == "train" else self.val_data
        # Sampling on CPU with an optional generator keeps batches reproducible.
        ix = torch.randint(len(data) - self.block_size - 1, (batch_size,), generator=generator)
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        # Targets are the inputs shifted right by one token.
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)
|
| 87 |
+
|
| 88 |
+
# -----------------------------
|
| 89 |
+
# Chunked Sparse Autograd
|
| 90 |
+
# -----------------------------
|
| 91 |
+
class ChunkedMaskedLinear(torch.autograd.Function):
    """Linear op with a dense forward and a chunk-sparse backward.

    Forward is a plain F.linear. Backward computes weight/bias gradients only
    for the output rows in `active_chunks` (others are exact zeros); the input
    gradient is either exact (full weight matmul) or restricted to active
    chunks when `sparse_dx` is set.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], active_chunks: torch.Tensor, chunk_size: int, sparse_dx: bool) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        chunk_size = ctx.chunk_size

        # Flatten leading (batch, seq, ...) dims into a single row axis.
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight)
        grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if ctx.has_bias else None

        if ctx.sparse_dx:
            # dX accumulated chunk-by-chunk below: only active chunks contribute.
            grad_x_flat = torch.zeros_like(x_flat)
        else:
            # Exact dX via the full weight matrix.
            grad_x_flat = gy_flat @ weight

        # Zero-copy Strided Views feeding directly into Dense Hardware Matmuls
        for c_idx in active_chunks.tolist():
            start = c_idx * chunk_size
            end = start + chunk_size

            gy_slice = gy_flat[:, start:end]
            w_slice = weight[start:end, :]

            grad_w[start:end, :] = gy_slice.t() @ x_flat

            if ctx.has_bias:
                grad_b[start:end] = gy_slice.sum(dim=0)

            if ctx.sparse_dx:
                grad_x_flat += gy_slice @ w_slice

        # Trailing Nones: active_chunks, chunk_size and sparse_dx take no grad.
        return grad_x_flat.reshape(x.shape), grad_w, grad_b, None, None, None
|
| 133 |
+
|
| 134 |
+
class SparseLinear(nn.Linear):
    """nn.Linear that can route its backward through ChunkedMaskedLinear.

    A `chunk_size` attribute is attached externally by the training loop;
    forward falls back to 64 when it is absent.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        # When False (or no chunks chosen) this behaves exactly like nn.Linear.
        self.sparse_enabled = False
        # When True, dX is also restricted to active chunks (approximate grads).
        self.sparse_dx = False
        # Local chunk indices (set by ChunkMasker) whose grads are computed.
        self.active_chunks: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.sparse_enabled or self.active_chunks is None:
            return F.linear(x, self.weight, self.bias)
        return ChunkedMaskedLinear.apply(x, self.weight, self.bias, self.active_chunks, getattr(self, 'chunk_size', 64), self.sparse_dx)
|
| 145 |
+
|
| 146 |
+
# -----------------------------
|
| 147 |
+
# GPT Architecture
|
| 148 |
+
# -----------------------------
|
| 149 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention whose projections are SparseLinear layers."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused QKV projection: one matmul producing q, k and v side by side.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask, shaped (1, 1, T, T) to broadcast over (batch, head).
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Scaled dot-product scores, with the causal mask applied before softmax.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        # Re-merge heads: (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
|
| 174 |
+
|
| 175 |
+
class FeedForward(nn.Module):
    """GPT-style MLP: expand to 4x width, GELU, project back down, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        self.c_fc = SparseLinear(n_embd, hidden)
        self.c_proj = SparseLinear(hidden, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
|
| 184 |
+
|
| 185 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual path."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
| 197 |
+
|
| 198 |
+
class GPT(nn.Module):
    """GPT: token + learned positional embeddings, n_layer pre-norm blocks, dense LM head."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        # LM head is Dense! Needs full output dist for CrossEntropy loss
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when `targets` is not supplied."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 220 |
+
|
| 221 |
+
def get_sparse_linears(model):
    """Collect every SparseLinear submodule of `model` in module-traversal order."""
    found = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
|
| 223 |
+
|
| 224 |
+
# -----------------------------
|
| 225 |
+
# Chunk Masker with Annealing
|
| 226 |
+
# -----------------------------
|
| 227 |
+
class ChunkMasker:
    """Chooses which output chunks of each SparseLinear receive gradients.

    Chunks are numbered globally across all SparseLinear modules; the active
    fraction is cosine-annealed from 1.0 (dense) down to `target_fraction`
    after a dense warmup phase.
    """

    def __init__(self, model: nn.Module, policy: Policy, target_fraction: float, chunk_size: int, device: str):
        self.policy = policy
        self.target_fraction = target_fraction
        self.chunk_size = chunk_size
        self.device = device

        self.linears = get_sparse_linears(model)
        # Map each module to its contiguous slice of the global chunk-id space.
        self.module_to_chunk_ids = {}
        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, f"out_features {m.out_features} not divisible by chunk size {chunk_size}"
            n_chunks = m.out_features // chunk_size
            self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device)
            offset += n_chunks

        self.n_chunks = offset
        # EMA of per-chunk gradient magnitude, consumed by "predicted_magnitude".
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # Global boolean mask of the chunks selected for the current step.
        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)

    def choose_active(self, step: int, warmup_steps: int, anneal_steps: int):
        # Fully dense during warmup, then cosine-decay the active fraction
        # down to target_fraction over anneal_steps.
        if step < warmup_steps:
            current_fraction = 1.0
        elif step < warmup_steps + anneal_steps:
            progress = (step - warmup_steps) / anneal_steps
            cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
            current_fraction = self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
        else:
            current_fraction = self.target_fraction

        if current_fraction >= 0.999:
            # Effectively dense: mark every chunk active on every module.
            self.active_chunks.fill_(True)
            for m, ids in self.module_to_chunk_ids.items():
                m.active_chunks = torch.arange(len(ids), device=self.device)
            return

        k = max(1, int(current_fraction * self.n_chunks))
        self.active_chunks.fill_(False)

        if self.policy == "random":
            self.active_chunks[torch.randperm(self.n_chunks, device=self.device)[:k]] = True
        elif self.policy == "predicted_magnitude":
            # Tiny random jitter breaks ties among equal (e.g. all-zero) scores.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            self.active_chunks[torch.topk(scores, k=k).indices] = True
        # NOTE(review): the "oracle_current" policy from the Policy alias is not
        # handled here; with it, no chunk would ever be selected — confirm intended.

        # Translate the global mask into per-module local chunk indices.
        for m, ids in self.module_to_chunk_ids.items():
            global_active = self.active_chunks[ids]
            local_ids = torch.arange(len(ids), device=self.device)
            m.active_chunks = local_ids[global_active]

    @torch.no_grad()
    def update_predictor(self, mass_beta=0.95):
        """EMA-update the predicted per-chunk gradient mass from current grads."""
        current_mass = torch.zeros_like(self.predicted_mass)
        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None: continue
            # Squared-norm of each chunk's weight-gradient rows.
            w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)
            current_mass[ids] = torch.sqrt(w_sq + 1e-30)

        # Only chunks that actually received gradients this step are observed;
        # inactive chunks keep their stale EMA estimate.
        observed = self.active_chunks
        self.predicted_mass[observed] = mass_beta * self.predicted_mass[observed] + (1.0 - mass_beta) * current_mass[observed]
|
| 289 |
+
|
| 290 |
+
# -----------------------------
|
| 291 |
+
# Chunked Adam
|
| 292 |
+
# -----------------------------
|
| 293 |
+
class ChunkedAdam:
    """Adam (without bias correction) that updates only active chunks of sparse params.

    Non-sparse parameters get a full dense Adam update. Hyperparameters are
    fixed here: beta1=0.9, beta2=0.999, eps=1e-8.
    """

    def __init__(self, model, lr=5e-4, chunk_size=64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        # Per-parameter Adam moments, created lazily on first step().
        self.state = {}

        # Remember which parameters belong to sparse modules.
        self.param_to_sparse_module = {}
        for m in get_sparse_linears(model):
            if m.weight is not None: self.param_to_sparse_module[m.weight] = m
            if m.bias is not None: self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        # Drop grads to None rather than zeroing in place.
        for p in self.model.parameters(): p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None: continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}

            exp_avg, exp_avg_sq = self.state[p]["m"], self.state[p]["v"]

            sparse_module = self.param_to_sparse_module.get(p)
            active_chunks = getattr(sparse_module, 'active_chunks', None) if sparse_module else None

            if active_chunks is None:
                # Dense update
                exp_avg.mul_(0.9).add_(p.grad, alpha=0.1)
                exp_avg_sq.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                update = exp_avg / (torch.sqrt(exp_avg_sq) + 1e-8)
                p.sub_(update, alpha=self.lr)
            else:
                # Sparse update
                for local_c in active_chunks.tolist():
                    start = local_c * self.chunk_size
                    end = (local_c + 1) * self.chunk_size

                    # Slices are views of the parent tensors, so the in-place
                    # ops below write straight into p and the stored moments.
                    p_chunk = p[start:end]
                    g_chunk = p.grad[start:end]
                    m_chunk = exp_avg[start:end]
                    v_chunk = exp_avg_sq[start:end]

                    m_chunk.mul_(0.9).add_(g_chunk, alpha=0.1)
                    v_chunk.mul_(0.999).addcmul_(g_chunk, g_chunk, value=0.001)

                    update = m_chunk / (torch.sqrt(v_chunk) + 1e-8)
                    p_chunk.sub_(update, alpha=self.lr)
|
| 342 |
+
|
| 343 |
+
# -----------------------------
|
| 344 |
+
# Training
|
| 345 |
+
# -----------------------------
|
| 346 |
+
def main():
    """Benchmark dense vs. sparse-backward training runs and report val loss.

    Bug fix: the post-training evaluation block was duplicated and nested
    (``model.eval(); with torch.no_grad():`` appeared twice); it now runs once.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--n_layer", type=int, default=6)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=50)
    parser.add_argument("--anneal_steps", type=int, default=200)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    corpus = ShakespeareCorpus(args.block_size, args.device)

    # (policy, backward mode) pairs to benchmark.
    modes = [
        ("dense_baseline", "dense_baseline"),
        ("predicted_magnitude", "sparse_dW_full_dX"),
        ("predicted_magnitude", "sparse_dW_sparse_dX"),
    ]

    print(f"\nModel: {args.n_layer} layers, {args.n_embd} d_model, {args.chunk_size} chunk_size")
    print(f"Batch: {args.batch_size}, Block: {args.block_size}. Target Active: {args.active_fraction*100}%")
    print(f"Annealing: {args.warmup_steps} warmup steps, {args.anneal_steps} anneal steps.\n")
    print(f"{'Run':>20s} | {'Time (s)':>10s} | {'Step (ms)':>10s} | {'Val Loss':>8s}")
    print("-" * 55)

    for policy, bwd_mode in modes:
        set_seed(42)
        model = GPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.1).to(args.device)

        for m in get_sparse_linears(model):
            m.chunk_size = args.chunk_size

        masker = ChunkMasker(model, policy, args.active_fraction, args.chunk_size, args.device) if policy != "dense_baseline" else None
        opt = ChunkedAdam(model, lr=5e-4, chunk_size=args.chunk_size)

        if args.benchmark_sync: sync_device(args.device)

        t0 = time.perf_counter()
        measured_steps = args.steps

        for step in range(args.steps):
            # Restart the clock once warmup+anneal is done so the reported
            # step time reflects steady-state behavior only.
            if step == args.warmup_steps + args.anneal_steps:
                if args.benchmark_sync: sync_device(args.device)
                t0 = time.perf_counter()
                measured_steps = args.steps - step

            x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))

            if masker:
                masker.choose_active(step, warmup_steps=args.warmup_steps, anneal_steps=args.anneal_steps)
                for m in get_sparse_linears(model):
                    m.sparse_enabled = True
                    m.sparse_dx = (bwd_mode == "sparse_dW_sparse_dX")
            else:
                for m in get_sparse_linears(model):
                    m.sparse_enabled = False
                    m.active_chunks = None

            opt.zero_grad()
            _, loss = model(x, y)
            loss.backward()

            if masker:
                masker.update_predictor()

            opt.step()

            # Optional: Print progress every 200 steps
            if step % 200 == 0:
                print(f" [Progress] {bwd_mode} step {step}/{args.steps} | Loss: {loss.item():.4f}", end="\r")

        if args.benchmark_sync: sync_device(args.device)
        t_elapsed = time.perf_counter() - t0

        # Eval loss (previously this block was accidentally duplicated/nested).
        model.eval()
        with torch.no_grad():
            val_x, val_y = corpus.get_batch("val", args.batch_size, generator=make_cpu_generator(999))
            _, val_loss = model(val_x, val_y)

        # Clear the progress line
        print(" " * 60, end="\r")

        bwd_str = bwd_mode if bwd_mode == "dense_baseline" else ("sparse_full_dX" if "full_dX" in bwd_mode else "sparse_sparse_dX")
        print(f"{bwd_str:>20s} | {t_elapsed:10.2f} | {1000*t_elapsed/max(1, measured_steps):10.2f} | {val_loss.item():8.4f}")

if __name__ == "__main__":
    main()
|
experiments/sparse_transformer_v15_inactive_prediction.py
ADDED
|
@@ -0,0 +1,729 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v15: Inactive-Update Prediction Diagnostics.
|
| 3 |
+
|
| 4 |
+
Tests two simple ideas:
|
| 5 |
+
|
| 6 |
+
1. Correlated-neighbor prediction:
|
| 7 |
+
Use active chunks as sensors. For each inactive chunk, find historically
|
| 8 |
+
correlated active chunks and predict its update magnitude from them.
|
| 9 |
+
|
| 10 |
+
2. Graph / boundary interpolation:
|
| 11 |
+
Treat chunks as nodes in a learned similarity graph. Active chunks are
|
| 12 |
+
boundary values. Inactive chunk magnitudes are filled in by diffusion.
|
| 13 |
+
|
| 14 |
+
This is intentionally a diagnostic script, not a speed benchmark.
|
| 15 |
+
It computes dense gradients every step so we can measure whether inactive
|
| 16 |
+
updates are predictable.
|
| 17 |
+
|
| 18 |
+
Run:
|
| 19 |
+
python3 sparse_transformer_v15_inactive_prediction.py --device mps --benchmark_sync
|
| 20 |
+
|
| 21 |
+
Good first runs:
|
| 22 |
+
python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 512
|
| 23 |
+
python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 1024
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import math
|
| 28 |
+
import random
|
| 29 |
+
import time
|
| 30 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
|
| 34 |
+
torch.set_num_threads(1)
|
| 35 |
+
import torch.nn as nn
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
Policy = Literal["predicted_magnitude", "random"]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def sync_device(device: str) -> None:
    """Block until queued kernels on *device* finish; CPU needs no barrier."""
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return
    if device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
| 48 |
+
|
| 49 |
+
def set_seed(seed: int) -> None:
    """Seed Python's and torch's global RNGs for reproducible runs."""
    torch.manual_seed(seed)
    random.seed(seed)
| 53 |
+
|
| 54 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Build a CPU-backed torch.Generator deterministically seeded with *seed*."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator
| 59 |
+
|
| 60 |
+
# -----------------------------
|
| 61 |
+
# Data
|
| 62 |
+
# -----------------------------
|
| 63 |
+
|
| 64 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Produce a deterministic toy corpus: newline-separated word-salad sentences."""
    rng = random.Random(seed)
    vocabulary = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sentences = []
    for _ in range(n_sentences):
        # Draw the length first so the RNG call order matches across runs.
        n_words = rng.randint(4, 10)
        sentences.append(" ".join(rng.choices(vocabulary, k=n_words)) + ".")
    return "\n".join(sentences)
| 76 |
+
|
| 77 |
+
class CharCorpus:
    """Character-level LM dataset with a fixed 90/10 train/validation split."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample (inputs, next-char targets); targets are inputs shifted by one."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        xs = [source[s : s + self.block_size] for s in starts]
        ys = [source[s + 1 : s + self.block_size + 1] for s in starts]
        return torch.stack(xs).to(self.device), torch.stack(ys).to(self.device)
| 102 |
+
|
| 103 |
+
# -----------------------------
|
| 104 |
+
# Model
|
| 105 |
+
# -----------------------------
|
| 106 |
+
|
| 107 |
+
class SparseLinear(nn.Linear):
    """Name retained for compatibility with earlier experiments.

    Behaves exactly like ``nn.Linear`` — no forward/backward override here.
    In this diagnostic script, backward is dense. We only use chunk masks
    analytically after gradients are computed (see ChunkMap), so the subclass
    exists purely so ``get_sparse_linears`` can find the layers to analyze.
    """
| 114 |
+
|
| 115 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an explicit lower-triangular mask."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq, dim = x.shape
        q, k, v = self.c_attn(x).split(dim, dim=2)

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, T, C] -> [B, heads, T, head_dim]
            return t.view(batch, seq, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :seq, :seq] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))

        out = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.c_proj(out)
| 147 |
+
|
| 148 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: expand 4x, GELU, project back, then dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
| 158 |
+
|
| 159 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each behind a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
| 172 |
+
|
| 173 |
+
class MiniGPT(nn.Module):
    """Small GPT: token + positional embeddings, N transformer blocks, LM head."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when *targets* is omitted."""
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        hidden = self.ln_f(self.blocks(hidden))
        logits = self.lm_head(hidden)

        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
| 206 |
+
|
| 207 |
+
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every SparseLinear submodule of *model*, in traversal order."""
    found: List[SparseLinear] = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
| 210 |
+
|
| 211 |
+
# -----------------------------
|
| 212 |
+
# Chunk geometry and diagnostics
|
| 213 |
+
# -----------------------------
|
| 214 |
+
|
| 215 |
+
class ChunkMap:
    """Global chunk bookkeeping over all SparseLinear layers.

    Each layer's output rows are cut into contiguous ``chunk_size`` chunks and
    assigned globally unique ids. The map also maintains per-chunk predictors:
    an EMA of gradient magnitude, an EMA of (unit) gradient direction, and a
    bounded history of per-step magnitudes used to build similarity graphs.
    """

    def __init__(self, model: nn.Module, chunk_size: int, device: str):
        self.model = model
        self.chunk_size = chunk_size
        self.device = device
        self.linears = get_sparse_linears(model)

        # Forward map: module -> tensor of its global chunk ids.
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        # Reverse map: global chunk id -> (module, local chunk index).
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []

        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            for local_c in range(n_chunks):
                self.chunk_to_module_local.append((m, local_c))
            offset += n_chunks

        self.n_chunks = offset
        # EMA of per-chunk gradient magnitude (0 means "never observed").
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # EMA of per-chunk unit gradient direction; None until first observed.
        self.direction_ema: List[Optional[torch.Tensor]] = [None for _ in range(self.n_chunks)]

        # Histories for correlation and graph similarities.
        self.mass_history: List[torch.Tensor] = []

    def choose_active(
        self,
        step: int,
        warmup_steps: int,
        active_fraction: float,
        policy: Policy,
    ) -> torch.Tensor:
        """Boolean mask of chunks to treat as 'active' this step.

        All chunks are active during warmup; afterwards the top
        ``active_fraction`` by predicted magnitude (or a random subset).
        """
        if step < warmup_steps:
            return torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)

        k = max(1, int(active_fraction * self.n_chunks))
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)

        if policy == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        else:
            # Tiny random jitter breaks ties among equal predicted masses.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices

        mask[idx] = True
        return mask

    @torch.no_grad()
    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Flattened [weight-rows | bias-slice] gradient per chunk (zeros if no grad)."""
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = (local_c + 1) * self.chunk_size

            parts = []
            if m.weight.grad is None:
                parts.append(torch.zeros_like(m.weight[start:end]).flatten())
            else:
                parts.append(m.weight.grad[start:end].detach().flatten())

            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())

            vecs.append(torch.cat(parts))
        return vecs

    @torch.no_grad()
    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """L2 norm of each chunk vector, stacked onto the map's device."""
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    @torch.no_grad()
    def update_predictor(
        self,
        active_mask: torch.Tensor,
        vecs: List[torch.Tensor],
        mass_beta: float = 0.95,
        dir_beta: float = 0.95,
        store_history: bool = True,
    ) -> torch.Tensor:
        """Fold this step's observed chunk gradients into the EMA predictors.

        Only chunks flagged in *active_mask* are treated as observed; returns
        the per-chunk masses computed from *vecs*.
        """
        masses = self.chunk_masses_from_vecs(vecs)

        observed = active_mask
        # First observation should initialize directly, not get shrunk by beta.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen
        self.predicted_mass[never_seen] = masses[never_seen]
        self.predicted_mass[already_seen] = (
            mass_beta * self.predicted_mass[already_seen]
            + (1.0 - mass_beta) * masses[already_seen]
        )

        # Direction EMA: blend unit gradient directions, skipping near-zero grads.
        for i, is_active in enumerate(observed.tolist()):
            if not is_active:
                continue
            v = vecs[i]
            n = v.norm()
            if n <= 1e-12:
                continue
            unit = v / n
            if self.direction_ema[i] is None:
                self.direction_ema[i] = unit.detach().clone()
            else:
                self.direction_ema[i] = (
                    dir_beta * self.direction_ema[i] + (1.0 - dir_beta) * unit
                )
            # Renormalize so the stored direction stays unit-length
            # (harmless on the fresh-init branch, where it is already unit).
            self.direction_ema[i] = self.direction_ema[i] / (self.direction_ema[i].norm() + 1e-12)

        if store_history:
            self.mass_history.append(masses.detach().clone())
            # Bound memory: keep only the most recent 128 snapshots.
            max_hist = 128
            if len(self.mass_history) > max_hist:
                self.mass_history = self.mass_history[-max_hist:]

        return masses

    def layer_aware_masks(self) -> List[torch.Tensor]:
        """One boolean mask per layer marking which global chunk ids it owns."""
        masks = []
        for m, ids in self.module_to_chunk_ids.items():
            mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
            mask[ids] = True
            masks.append(mask)
        return masks
|
| 345 |
+
|
| 346 |
+
def dense_cosine_from_vecs(a: List[torch.Tensor], b: List[torch.Tensor]) -> float:
    """Cosine similarity between two chunk-vector lists, flattened end to end."""
    flat_a = torch.cat([t.flatten() for t in a])
    flat_b = torch.cat([t.flatten() for t in b])
    return float(F.cosine_similarity(flat_a, flat_b, dim=0).item())
| 351 |
+
|
| 352 |
+
def mse_reduction_vs_zero(true_vecs: List[torch.Tensor], pred_vecs: List[torch.Tensor], mask: torch.Tensor) -> float:
    """Fraction of MSE removed by *pred_vecs* vs. predicting zero, over masked chunks.

    Returns NaN when the mask selects nothing; 1.0 means perfect prediction.
    """
    selected = torch.nonzero(mask, as_tuple=False).flatten().tolist()
    if not selected:
        return float("nan")
    truth = torch.cat([true_vecs[i].flatten() for i in selected])
    guess = torch.cat([pred_vecs[i].flatten() for i in selected])
    baseline_mse = torch.mean(truth.square())
    achieved_mse = torch.mean((truth - guess).square())
    return float((1.0 - achieved_mse / (baseline_mse + 1e-12)).item())
| 362 |
+
|
| 363 |
+
def active_only_prediction(true_vecs: List[torch.Tensor], active_mask: torch.Tensor) -> List[torch.Tensor]:
    """Copy true vectors for active chunks; predict zero for inactive ones."""
    return [
        vec.clone() if bool(active_mask[i]) else torch.zeros_like(vec)
        for i, vec in enumerate(true_vecs)
    ]
| 369 |
+
|
| 370 |
+
def ema_direction_prediction(
    cmap: ChunkMap,
    true_vecs: List[torch.Tensor],
    active_mask: torch.Tensor,
    inactive_magnitudes: torch.Tensor,
) -> List[torch.Tensor]:
    """Active chunks keep their true gradient; inactive ones get the stored
    EMA direction scaled to a predicted magnitude (zero if never observed)."""
    predictions: List[torch.Tensor] = []
    for i, vec in enumerate(true_vecs):
        if bool(active_mask[i]):
            predictions.append(vec.clone())
            continue
        ema_dir = cmap.direction_ema[i]
        if ema_dir is None:
            predictions.append(torch.zeros_like(vec))
        else:
            predictions.append(ema_dir.to(vec.device, vec.dtype) * inactive_magnitudes[i])
    return predictions
| 388 |
+
|
| 389 |
+
def build_mass_similarity(cmap: ChunkMap, min_history: int = 8) -> Optional[torch.Tensor]:
    """Chunk-by-chunk similarity from the magnitude history.

    Standardizes the history, takes positive correlations only, zeroes the
    diagonal, and restricts similarity to same-layer pairs. Returns None until
    at least *min_history* snapshots exist.
    """
    if len(cmap.mass_history) < min_history:
        return None

    hist = torch.stack(cmap.mass_history, dim=0)  # [history, chunks]
    hist = hist - hist.mean(dim=0, keepdim=True)
    hist = hist / (hist.std(dim=0, keepdim=True) + 1e-6)

    sim = (hist.T @ hist) / max(1, hist.shape[0] - 1)
    sim = torch.clamp(sim, min=0.0)

    # A chunk should not be its own neighbor.
    sim.fill_diagonal_(0.0)

    # Layer-aware block diagonal: avoid mixing unrelated layers by default.
    allowed = torch.zeros_like(sim, dtype=torch.bool)
    for layer_mask in cmap.layer_aware_masks():
        allowed |= layer_mask[:, None] & layer_mask[None, :]
    return torch.where(allowed, sim, torch.zeros_like(sim))
| 412 |
+
|
| 413 |
+
def knn_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    k_neighbors: int = 3,
) -> torch.Tensor:
    """Predict inactive magnitudes as weighted average of correlated active magnitudes."""
    sim = build_mass_similarity(cmap)
    if sim is None:
        # Not enough history yet: fall back to the EMA magnitude predictor.
        fallback = cmap.predicted_mass.clone()
        fallback[active_mask] = true_masses[active_mask]
        return fallback

    pred = torch.zeros_like(true_masses)
    pred[active_mask] = true_masses[active_mask]

    active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
    inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()
    if active_idx.numel() == 0:
        return pred

    for i in inactive_idx.tolist():
        weights = sim[i, active_idx]
        if weights.sum() <= 1e-12:
            # No correlated active sensor: keep the EMA estimate.
            pred[i] = cmap.predicted_mass[i]
            continue

        kk = min(k_neighbors, weights.numel())
        top = torch.topk(weights, k=kk)
        w = top.values
        neighbors = active_idx[top.indices]
        pred[i] = (w * true_masses[neighbors]).sum() / (w.sum() + 1e-12)

    return pred
| 449 |
+
|
| 450 |
+
def graph_diffusion_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    diffusion_steps: int = 8,
    alpha: float = 0.7,
) -> torch.Tensor:
    """Boundary-value style magnitude interpolation over a learned similarity graph.

    Active nodes are clamped to observed true magnitudes; inactive nodes relax
    toward the weighted average of their graph neighbors.
    """
    sim = build_mass_similarity(cmap)
    if sim is None:
        fallback = cmap.predicted_mass.clone()
        fallback[active_mask] = true_masses[active_mask]
        return fallback

    # Row-normalize similarities into averaging weights.
    transition = sim / (sim.sum(dim=1, keepdim=True) + 1e-12)

    estimate = cmap.predicted_mass.clone()
    estimate[active_mask] = true_masses[active_mask]

    for _ in range(diffusion_steps):
        estimate = alpha * (transition @ estimate) + (1.0 - alpha) * estimate
        # Re-clamp the boundary after every relaxation pass.
        estimate[active_mask] = true_masses[active_mask]

    return torch.clamp(estimate, min=0.0)
| 480 |
+
|
| 481 |
+
# -----------------------------
|
| 482 |
+
# Optimizer
|
| 483 |
+
# -----------------------------
|
| 484 |
+
|
| 485 |
+
class SimpleAdam:
    """Minimal Adam-style optimizer for apples-to-apples diagnostics.

    Deliberately simple and identical across runs: fixed betas (0.9 / 0.999),
    eps 1e-8, no bias correction. Not production AdamW.
    """

    def __init__(self, model: nn.Module, lr: float = 3e-4):
        self.model = model
        self.lr = lr
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

    def zero_grad(self):
        """Drop every parameter gradient."""
        for param in self.model.parameters():
            param.grad = None

    @torch.no_grad()
    def step(self):
        """One EMA-moment update for every parameter that has a gradient."""
        for param in self.model.parameters():
            grad = param.grad
            if grad is None:
                continue
            if param not in self.state:
                self.state[param] = {
                    "m": torch.zeros_like(param),
                    "v": torch.zeros_like(param),
                }
            slot = self.state[param]
            slot["m"].mul_(0.9).add_(grad, alpha=0.1)
            slot["v"].mul_(0.999).addcmul_(grad, grad, value=0.001)
            param.sub_(slot["m"] / (torch.sqrt(slot["v"]) + 1e-8), alpha=self.lr)
| 517 |
+
|
| 518 |
+
# -----------------------------
|
| 519 |
+
# Apply chunk-gradient predictions
|
| 520 |
+
# -----------------------------
|
| 521 |
+
|
| 522 |
+
@torch.no_grad()
def install_chunk_prediction_as_grads(
    cmap: ChunkMap,
    pred_vecs: List[torch.Tensor],
):
    """Overwrite SparseLinear weight/bias grads from predicted chunk vectors.

    Each predicted vector is laid out as [weight rows flattened | bias slice],
    mirroring ChunkMap.chunk_gradient_vectors. Non-SparseLinear parameters
    keep their dense gradients.
    """
    for m, ids in cmap.module_to_chunk_ids.items():
        # No backward has run for this module: nothing to overwrite.
        if m.weight.grad is None:
            continue
        # Start from zero so chunks without predictions contribute no update.
        m.weight.grad.zero_()
        if m.bias is not None and m.bias.grad is not None:
            m.bias.grad.zero_()

        for local_c, global_id in enumerate(ids.tolist()):
            start = local_c * cmap.chunk_size
            end = (local_c + 1) * cmap.chunk_size

            v = pred_vecs[global_id]
            # Leading w_numel entries are the chunk's weight rows, flattened.
            w_numel = cmap.chunk_size * m.weight.shape[1]
            w_flat = v[:w_numel]
            m.weight.grad[start:end] = w_flat.view(cmap.chunk_size, m.weight.shape[1])

            if m.bias is not None and m.bias.grad is not None:
                # Remainder (if any) is the chunk's bias slice.
                b_flat = v[w_numel:]
                if b_flat.numel() > 0:
                    m.bias.grad[start:end] = b_flat.view(cmap.chunk_size)
| 552 |
+
|
| 553 |
+
# -----------------------------
|
| 554 |
+
# Training / diagnostics
|
| 555 |
+
# -----------------------------
|
| 556 |
+
|
| 557 |
+
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Loss on one fixed-seed validation batch; restores train mode before returning."""
    model.eval()
    with torch.no_grad():
        inputs, targets = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
        _, loss = model(inputs, targets)
    model.train()
    return float(loss.item())
| 565 |
+
|
| 566 |
+
def run_experiment(
    mode: str,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    policy: Policy,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train MiniGPT with one inactive-gradient prediction mode and report diagnostics.

    Each step computes the full dense gradient, then replaces the SparseLinear
    chunk gradients with predictions according to `mode` before the optimizer
    step. Returns a dict with final val loss ("val"), ms per step ("ms"),
    mean dense cosine between true and predicted gradients ("cos"), and mean
    inactive-chunk MSE reduction vs. a zero prediction ("mse_red").
    """
    set_seed(42)
    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    cmap = ChunkMap(model, chunk_size=chunk_size, device=device)

    metric_rows = []

    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()

    for step in range(steps):
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))

        # Dense backward: the true gradient is always available for diagnostics.
        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        true_vecs = cmap.chunk_gradient_vectors()
        true_masses = cmap.chunk_masses_from_vecs(true_vecs)

        active_mask = cmap.choose_active(
            step=step,
            warmup_steps=warmup_steps,
            active_fraction=active_fraction,
            policy=policy,
        )

        # During warmup (and in the dense baseline) predictions are exact copies.
        if step < warmup_steps or mode == "dense":
            pred_vecs = [v.clone() for v in true_vecs]
        else:
            active_only_vecs = active_only_prediction(true_vecs, active_mask)

            if mode == "active_only":
                pred_vecs = active_only_vecs

            elif mode == "knn_magnitude":
                pred_masses = knn_magnitude_prediction(cmap, active_mask, true_masses)
                pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)

            elif mode == "graph_diffusion":
                pred_masses = graph_diffusion_magnitude_prediction(cmap, active_mask, true_masses)
                pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)

            elif mode == "ema_inactive":
                # Active chunks keep their observed mass; inactive use the EMA.
                pred_masses = cmap.predicted_mass.clone()
                pred_masses[active_mask] = true_masses[active_mask]
                pred_vecs = ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)

            else:
                raise ValueError(f"Unknown mode: {mode}")

        install_chunk_prediction_as_grads(cmap, pred_vecs)

        # Sample diagnostics every 25 steps (includes step 0).
        if step % 25 == 0:
            inactive_mask = ~active_mask
            row = {
                "cosine_full": dense_cosine_from_vecs(true_vecs, pred_vecs),
                "inactive_mse_reduction": mse_reduction_vs_zero(true_vecs, pred_vecs, inactive_mask),
                "active_frac": float(active_mask.float().mean().item()),
                "val": evaluate(model, corpus, batch_size, seed=999 + step),
            }
            metric_rows.append(row)

        # Update predictor after measuring and installing predicted grads.
        # Use true active chunk observations only, mimicking sparse observation.
        cmap.update_predictor(active_mask, true_vecs, store_history=True)

        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0

    val_loss = evaluate(model, corpus, batch_size, seed=12345)

    if metric_rows:
        avg_cos = sum(r["cosine_full"] for r in metric_rows) / len(metric_rows)
        avg_mse_red = sum(r["inactive_mse_reduction"] for r in metric_rows) / len(metric_rows)
    else:
        avg_cos = float("nan")
        avg_mse_red = float("nan")

    return {
        "val": val_loss,
        "ms": 1000.0 * elapsed / steps,
        "cos": avg_cos,
        "mse_red": avg_mse_red,
    }
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def main():
    """Run every prediction mode with shared hyperparameters and print a summary table."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=300)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--policy", type=str, default="predicted_magnitude", choices=["predicted_magnitude", "random"])
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    print("\nInactive-update prediction diagnostic")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunks={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps} policy={args.policy}\n")
    print(f"{'mode':>18s} | {'val':>8s} | {'ms/step':>8s} | {'grad_cos':>8s} | {'inactive_mse+':>13s}")
    print("-" * 70)

    for mode in ("dense", "active_only", "ema_inactive", "knn_magnitude", "graph_diffusion"):
        result = run_experiment(
            mode=mode,
            device=args.device,
            steps=args.steps,
            batch_size=args.batch_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            chunk_size=args.chunk_size,
            active_fraction=args.active_fraction,
            warmup_steps=args.warmup_steps,
            policy=args.policy,
            benchmark_sync=args.benchmark_sync,
        )
        row = (
            f"{mode:>18s} | {result['val']:8.4f} | {result['ms']:8.2f} | "
            f"{result['cos']:8.3f} | {result['mse_red']:13.3f}"
        )
        print(row)
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
# Script entry point: run the full mode comparison.
if __name__ == "__main__":
    main()
|
experiments/sparse_transformer_v16_sensor_scheduler.py
ADDED
|
@@ -0,0 +1,677 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v16: Sensor-Based Mask Scheduling.
|
| 3 |
+
|
| 4 |
+
v15 showed that directly hallucinating inactive gradient vectors was harmful.
|
| 5 |
+
v16 tests the safer next idea:
|
| 6 |
+
|
| 7 |
+
Use active chunks as sensors to choose which chunks receive real gradients next.
|
| 8 |
+
|
| 9 |
+
No inactive gradient is invented. In sparse modes, inactive chunks get zero gradient.
|
| 10 |
+
The only question is whether active chunk observations improve future mask selection.
|
| 11 |
+
|
| 12 |
+
Schedulers:
|
| 13 |
+
dense
|
| 14 |
+
Dense baseline.
|
| 15 |
+
|
| 16 |
+
ema_topk
|
| 17 |
+
Select top chunks by each chunk's own EMA gradient mass.
|
| 18 |
+
|
| 19 |
+
knn_scheduler
|
| 20 |
+
Use active chunks as sensors. Predict next-step inactive chunk mass from
|
| 21 |
+
historically correlated active chunks. Select next mask from that score.
|
| 22 |
+
|
| 23 |
+
graph_scheduler
|
| 24 |
+
Boundary-value style magnitude diffusion over a chunk similarity graph.
|
| 25 |
+
Active chunks are clamped to observed magnitudes. Inactive magnitudes are
|
| 26 |
+
interpolated and used to choose the next mask.
|
| 27 |
+
|
| 28 |
+
random
|
| 29 |
+
Random sparse-support control.
|
| 30 |
+
|
| 31 |
+
This is still a diagnostic/simulation script: it computes dense gradients so we can
|
| 32 |
+
measure oracle Jaccard/cosine, then installs only the selected active chunk gradients
|
| 33 |
+
for sparse training.
|
| 34 |
+
|
| 35 |
+
Run:
|
| 36 |
+
python3 sparse_transformer_v16_sensor_scheduler.py --device mps --benchmark_sync
|
| 37 |
+
|
| 38 |
+
Useful:
|
| 39 |
+
python3 sparse_transformer_v16_sensor_scheduler.py --device mps --steps 500 --n_embd 512
|
| 40 |
+
python3 sparse_transformer_v16_sensor_scheduler.py --device mps --steps 500 --n_embd 1024
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
from __future__ import annotations
|
| 44 |
+
|
| 45 |
+
import argparse
|
| 46 |
+
import math
|
| 47 |
+
import random
|
| 48 |
+
import time
|
| 49 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 50 |
+
|
| 51 |
+
import torch
|
| 52 |
+
|
| 53 |
+
torch.set_num_threads(1)
|
| 54 |
+
import torch.nn as nn
|
| 55 |
+
import torch.nn.functional as F
|
| 56 |
+
|
| 57 |
+
Scheduler = Literal["dense", "ema_topk", "knn_scheduler", "graph_scheduler", "random"]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def sync_device(device: str) -> None:
    """Block until all queued kernels on *device* have completed.

    No-op for CPU, or when the requested backend is not actually available.

    Fix: the original called ``torch.mps.synchronize()`` whenever torch merely
    *had* an ``mps`` attribute, which raises on machines without an MPS
    backend; we now also require ``torch.backends.mps.is_available()``.
    """
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps") and torch.backends.mps.is_available():
        torch.mps.synchronize()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def set_seed(seed: int) -> None:
    """Make the run reproducible: seed torch and Python's `random` module."""
    torch.manual_seed(seed)
    random.seed(seed)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a freshly seeded CPU-side ``torch.Generator``."""
    g = torch.Generator(device="cpu")
    g.manual_seed(seed)
    return g
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# -----------------------------
|
| 79 |
+
# Data
|
| 80 |
+
# -----------------------------
|
| 81 |
+
|
| 82 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic toy corpus: newline-separated, dot-terminated sentences.

    Each sentence is 4-10 words sampled (with replacement) from a small
    fixed vocabulary, driven by a seeded private RNG.
    """
    rng = random.Random(seed)
    vocab = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sentences = []
    for _ in range(n_sentences):
        # NOTE: randint must be drawn before choices to preserve the RNG stream.
        length = rng.randint(4, 10)
        sentences.append(" ".join(rng.choices(vocab, k=length)) + ".")
    return "\n".join(sentences)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class CharCorpus:
    """Character-level corpus with a fixed 90/10 train/val split."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device
        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample `batch_size` (input, next-char target) windows from one split."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        inputs = torch.stack([source[s : s + self.block_size] for s in starts])
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# -----------------------------
|
| 121 |
+
# Model
|
| 122 |
+
# -----------------------------
|
| 123 |
+
|
| 124 |
+
class SparseLinear(nn.Linear):
    """Marker subclass of nn.Linear: identical behavior, but tags layers whose
    output rows are eligible for chunked sparse gradient scheduling."""
    pass
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a precomputed lower-triangular mask."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused QKV projection and the output projection, both chunk-schedulable.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Causal mask shaped [1, 1, T, T] so it broadcasts over (batch, head).
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)

        # [B, T, C] -> [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores; future positions are masked to -inf
        # before softmax so attention stays strictly causal.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        # Merge heads back into the embedding dimension.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class FeedForward(nn.Module):
    """Two-layer GELU MLP (4x expansion) with dropout on the output."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each behind a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class MiniGPT(nn.Module):
    """Small GPT-style decoder: token+position embeddings, N blocks, LM head."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        # The LM head is a plain nn.Linear, so it is never chunk-sparsified.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are given."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # Flatten (B, T) positions into one cross-entropy over the vocab.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Return every SparseLinear in `model`, in module-traversal order."""
    return [mod for mod in model.modules() if isinstance(mod, SparseLinear)]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# -----------------------------
|
| 225 |
+
# Chunk map and scheduler
|
| 226 |
+
# -----------------------------
|
| 227 |
+
|
| 228 |
+
class ChunkScheduler:
    """Track per-chunk gradient statistics and select the active-chunk mask.

    Every SparseLinear's output rows are partitioned into contiguous chunks of
    ``chunk_size`` rows (weight rows plus the matching bias slice). The
    scheduler keeps an EMA of each chunk's observed gradient mass and, for the
    sensor-based schedulers, a within-layer chunk-similarity graph learned from
    dense warmup observations.

    Fix vs. original: ``layer_allowed_mask`` used to build the full allowed
    matrix in a loop, immediately ``zero_()`` it (its own comment called the
    first pass a "placeholder overwritten below"), and rebuild it — dead O(n^2)
    work now removed. ``build_similarity`` also reuses that helper instead of
    duplicating the same loop.
    """

    def __init__(
        self,
        model: nn.Module,
        chunk_size: int,
        active_fraction: float,
        device: str,
        scheduler: Scheduler,
        mass_beta: float = 0.95,
    ):
        self.model = model
        self.chunk_size = chunk_size
        self.active_fraction = active_fraction
        self.device = device
        self.scheduler = scheduler
        self.mass_beta = mass_beta  # EMA decay for observed chunk gradient mass

        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []

        # Assign global chunk ids layer by layer, in module order.
        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            for local_c in range(n_chunks):
                self.chunk_to_module_local.append((m, local_c))
            offset += n_chunks

        self.n_chunks = offset
        # EMA of each chunk's observed gradient L2 mass (0 == never observed).
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # Dense mass snapshots collected during warmup for the similarity graph.
        self.mass_history: List[torch.Tensor] = []

        self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=device)
        # Scores for the *next* mask, refreshed from this step's active sensors.
        self.next_scores = torch.zeros(self.n_chunks, device=device)

        self.prev_mask: Optional[torch.Tensor] = None
        self.similarity: Optional[torch.Tensor] = None

    def k_active(self) -> int:
        """Number of chunks that receive real gradients each sparse step."""
        return max(1, int(self.active_fraction * self.n_chunks))

    def choose_mask(self, step: int, warmup_steps: int) -> torch.Tensor:
        """Select and store the boolean active-chunk mask for this step."""
        # Dense baseline and warmup steps train every chunk.
        if self.scheduler == "dense" or step < warmup_steps:
            self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)
            return self.current_mask

        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)

        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]

        elif self.scheduler == "ema_topk":
            # Tiny noise breaks ties between equal EMA masses.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices

        elif self.scheduler in ("knn_scheduler", "graph_scheduler"):
            # next_scores are computed from the previous step's active sensors.
            # If unavailable, fall back to EMA.
            base = self.next_scores
            if torch.count_nonzero(base).item() == 0:
                base = self.predicted_mass
            scores = base + 1e-9 * torch.rand_like(base)
            idx = torch.topk(scores, k=k).indices

        else:
            raise ValueError(f"Unknown scheduler: {self.scheduler}")

        mask[idx] = True
        self.current_mask = mask
        return mask

    @torch.no_grad()
    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Flattened (weight rows + bias slice) gradient vector per chunk.

        Missing grads are substituted with zeros so the layout is stable.
        """
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = (local_c + 1) * self.chunk_size

            parts = []
            if m.weight.grad is None:
                parts.append(torch.zeros_like(m.weight[start:end]).flatten())
            else:
                parts.append(m.weight.grad[start:end].detach().flatten())

            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())

            vecs.append(torch.cat(parts))
        return vecs

    @torch.no_grad()
    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """L2 norm of each chunk gradient vector, as a [n_chunks] tensor."""
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    @torch.no_grad()
    def update_from_observed(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        step: int,
        warmup_steps: int,
    ) -> None:
        """Fold this step's observed active-chunk masses into the predictor state."""
        observed = active_mask

        # First observation initializes the EMA; later ones decay into it.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen

        self.predicted_mass[never_seen] = true_masses[never_seen]
        self.predicted_mass[already_seen] = (
            self.mass_beta * self.predicted_mass[already_seen]
            + (1.0 - self.mass_beta) * true_masses[already_seen]
        )

        # During warmup we store dense mass histories to learn the similarity graph.
        if step < warmup_steps:
            self.mass_history.append(true_masses.detach().clone())
            max_hist = 128
            if len(self.mass_history) > max_hist:
                self.mass_history = self.mass_history[-max_hist:]

            if len(self.mass_history) >= 8:
                self.similarity = self.build_similarity()

        # Compute next_scores from current active observations.
        if self.scheduler == "knn_scheduler":
            self.next_scores = self.knn_scores(active_mask, true_masses)
        elif self.scheduler == "graph_scheduler":
            self.next_scores = self.diffusion_scores(active_mask, true_masses)
        else:
            self.next_scores = self.predicted_mass.clone()

    def layer_allowed_mask(self) -> torch.Tensor:
        """Boolean [n_chunks, n_chunks] matrix: True iff both chunks share a layer."""
        allowed = torch.zeros((self.n_chunks, self.n_chunks), dtype=torch.bool, device=self.device)
        for _, ids in self.module_to_chunk_ids.items():
            allowed[ids[:, None], ids[None, :]] = True
        return allowed

    def build_similarity(self) -> torch.Tensor:
        """Non-negative within-layer correlation of chunk mass time series."""
        H = torch.stack(self.mass_history, dim=0)  # [history, chunks]
        # Standardize each chunk's series so S approximates Pearson correlation.
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)

        S = (H.T @ H) / max(1, H.shape[0] - 1)
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)

        # Keep only within-layer similarities. Cross-layer correlation is too easy
        # to overfit in this tiny diagnostic.
        S = torch.where(self.layer_allowed_mask(), S, torch.zeros_like(S))
        return S

    def knn_scores(self, active_mask: torch.Tensor, true_masses: torch.Tensor, k_neighbors: int = 3) -> torch.Tensor:
        """Predict every chunk's next mass from its most-correlated active sensors."""
        if self.similarity is None:
            return self.predicted_mass.clone()

        S = self.similarity
        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]

        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()

        if active_idx.numel() == 0:
            return scores

        for i in inactive_idx.tolist():
            weights = S[i, active_idx]
            if weights.sum() <= 1e-12:
                continue  # no correlated sensor: keep the EMA estimate
            kk = min(k_neighbors, weights.numel())
            top = torch.topk(weights, k=kk)
            w = top.values
            aidx = active_idx[top.indices]
            # Similarity-weighted average of the observed sensor masses.
            scores[i] = (w * true_masses[aidx]).sum() / (w.sum() + 1e-12)

        return scores

    def diffusion_scores(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        diffusion_steps: int = 8,
        alpha: float = 0.7,
    ) -> torch.Tensor:
        """Diffuse observed active masses over the similarity graph.

        Boundary-value style: active chunks are re-clamped to their observed
        masses after every diffusion step.
        """
        if self.similarity is None:
            return self.predicted_mass.clone()

        S = self.similarity
        W = S / (S.sum(dim=1, keepdim=True) + 1e-12)  # row-normalized weights

        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]

        for _ in range(diffusion_steps):
            proposal = W @ scores
            scores = alpha * proposal + (1.0 - alpha) * scores
            scores[active_mask] = true_masses[active_mask]

        return torch.clamp(scores, min=0.0)

    def oracle_topk_mask(self, true_masses: torch.Tensor) -> torch.Tensor:
        """Mask an oracle would pick: top-k chunks by true gradient mass."""
        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        mask[torch.topk(true_masses, k=k).indices] = True
        return mask
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
# -----------------------------
|
| 452 |
+
# Gradient installation and metrics
|
| 453 |
+
# -----------------------------
|
| 454 |
+
|
| 455 |
+
@torch.no_grad()
def install_active_only_grads(sched: ChunkScheduler, active_mask: torch.Tensor) -> None:
    """Zero the gradients of every inactive chunk; dense runs are untouched.

    No gradient is invented: inactive chunks simply receive zero gradient.
    """
    if sched.scheduler == "dense":
        return

    csz = sched.chunk_size
    for module, chunk_ids in sched.module_to_chunk_ids.items():
        inactive_local = [
            local_c
            for local_c, is_active in enumerate(active_mask[chunk_ids].tolist())
            if not is_active
        ]

        grads = [module.weight.grad]
        if module.bias is not None:
            grads.append(module.bias.grad)

        for g in grads:
            if g is None:
                continue
            for local_c in inactive_local:
                g[local_c * csz : (local_c + 1) * csz].zero_()
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def dense_cosine_active_only(vecs: List[torch.Tensor], active_mask: torch.Tensor) -> float:
    """Cosine between the full dense gradient and its active-only sparsification."""
    dense = torch.cat([v.flatten() for v in vecs])
    sparse = torch.cat(
        [
            v.flatten() if bool(active_mask[i]) else torch.zeros_like(v).flatten()
            for i, v in enumerate(vecs)
        ]
    )
    return float(F.cosine_similarity(dense, sparse, dim=0).item())
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
    """Jaccard overlap of two boolean masks (0.0 when both are empty)."""
    intersection = (a & b).sum().float()
    union = torch.clamp((a | b).sum().float(), min=1.0)
    return float((intersection / union).item())
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class SimpleAdam:
    """Minimal Adam-style optimizer: fixed betas (0.9, 0.999), no bias correction."""

    def __init__(self, model: nn.Module, lr: float = 3e-4):
        self.model = model
        self.lr = lr
        # Per-parameter first/second moment buffers, created lazily.
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

    def zero_grad(self):
        """Drop all parameter gradients."""
        for param in self.model.parameters():
            param.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for param in self.model.parameters():
            grad = param.grad
            if grad is None:
                continue
            if param not in self.state:
                self.state[param] = {"m": torch.zeros_like(param), "v": torch.zeros_like(param)}
            slot = self.state[param]
            m, v = slot["m"], slot["v"]
            m.mul_(0.9).add_(grad, alpha=0.1)
            v.mul_(0.999).addcmul_(grad, grad, value=0.001)
            param.sub_(m / (torch.sqrt(v) + 1e-8), alpha=self.lr)
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Return the loss on one seeded validation batch.

    The seed pins the sampled batch so runs are comparable across schedulers.
    The model is restored to train mode before returning.
    """
    model.eval()
    with torch.no_grad():
        x, y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
        _, loss = model(x, y)
    model.train()
    return float(loss.item())
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def run_experiment(
    scheduler_name: Scheduler,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train one model with the given scheduler and report diagnostic metrics.

    Returns a dict with:
      - "val": validation loss after training
      - "ms": average milliseconds per step (wall clock; synced if benchmark_sync)
      - "cos": mean cosine between dense gradient and its active-only truncation
      - "jacc": mean Jaccard overlap between chosen mask and the oracle top-k mask
      - "stable": mean Jaccard overlap between consecutive masks
    Metrics are only accumulated after warmup and for non-dense schedulers.
    """
    set_seed(42)

    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    sched = ChunkScheduler(
        model=model,
        chunk_size=chunk_size,
        active_fraction=active_fraction,
        device=device,
        scheduler=scheduler_name,
    )

    metric_rows = []

    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()

    for step in range(steps):
        # Per-step seeded batch keeps the data identical across schedulers.
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))

        active_mask = sched.choose_mask(step=step, warmup_steps=warmup_steps)

        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        # Dense gradients are computed every step so diagnostics can compare
        # the chosen mask against the true chunk gradient masses.
        vecs = sched.chunk_gradient_vectors()
        masses = sched.chunk_masses_from_vecs(vecs)

        if step >= warmup_steps and scheduler_name != "dense":
            oracle = sched.oracle_topk_mask(masses)
            row = {
                "cos": dense_cosine_active_only(vecs, active_mask),
                "jacc": jaccard(active_mask, oracle),
                "stable": jaccard(active_mask, sched.prev_mask) if sched.prev_mask is not None else 0.0,
                # Validation is expensive; sample it every 50 steps only.
                "val": evaluate(model, corpus, batch_size, seed=10_000 + step) if step % 50 == 0 else float("nan"),
            }
            metric_rows.append(row)

        # Zero inactive-chunk gradients so the optimizer only updates active chunks.
        install_active_only_grads(sched, active_mask)

        # Important: update scheduler from the active observations only.
        # Dense gradients exist for diagnostics, but unselected chunks should not
        # teach the sparse scheduler after warmup.
        observed_for_scheduler = active_mask if step >= warmup_steps else torch.ones_like(active_mask)
        sched.update_from_observed(
            active_mask=observed_for_scheduler,
            true_masses=masses,
            step=step,
            warmup_steps=warmup_steps,
        )

        sched.prev_mask = active_mask.clone()

        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0

    val_loss = evaluate(model, corpus, batch_size, seed=12345)

    if metric_rows:
        avg_cos = sum(r["cos"] for r in metric_rows) / len(metric_rows)
        avg_jacc = sum(r["jacc"] for r in metric_rows) / len(metric_rows)
        avg_stable = sum(r["stable"] for r in metric_rows) / len(metric_rows)
    else:
        # Dense runs (or all-warmup runs) collect no rows.
        avg_cos = float("nan")
        avg_jacc = float("nan")
        avg_stable = float("nan")

    return {
        "val": val_loss,
        "ms": 1000.0 * elapsed / steps,
        "cos": avg_cos,
        "jacc": avg_jacc,
        "stable": avg_stable,
    }
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def main() -> None:
    """Run every scheduler variant with shared hyperparameters and print a comparison table."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    # Schedulers compared in this diagnostic; "dense" is the upper-bound baseline
    # and "random" the lower-bound baseline.
    schedulers: List[Scheduler] = [
        "dense",
        "ema_topk",
        "knn_scheduler",
        "graph_scheduler",
        "random",
    ]

    print("\nSensor-based mask scheduling diagnostic")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunks={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps}\n")
    print(f"{'scheduler':>18s} | {'val':>8s} | {'ms/step':>8s} | {'grad_cos':>8s} | {'jacc':>8s} | {'stable':>8s}")
    print("-" * 78)

    for sched_name in schedulers:
        result = run_experiment(
            scheduler_name=sched_name,
            device=args.device,
            steps=args.steps,
            batch_size=args.batch_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            chunk_size=args.chunk_size,
            active_fraction=args.active_fraction,
            warmup_steps=args.warmup_steps,
            benchmark_sync=args.benchmark_sync,
        )

        print(
            f"{sched_name:>18s} | "
            f"{result['val']:8.4f} | "
            f"{result['ms']:8.2f} | "
            f"{result['cos']:8.3f} | "
            f"{result['jacc']:8.3f} | "
            f"{result['stable']:8.3f}"
        )
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
# Script entry point: run the scheduler comparison when executed directly.
if __name__ == "__main__":
    main()
|
experiments/sparse_transformer_v17_radar_scheduler.py
ADDED
|
@@ -0,0 +1,725 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v17: Radar Scheduler Diagnostic.
|
| 3 |
+
|
| 4 |
+
v16 showed:
|
| 5 |
+
- Directly predicting inactive gradient vectors is harmful.
|
| 6 |
+
- Using active chunks as sensors to schedule the next active mask works.
|
| 7 |
+
- KNN/graph sensors improve oracle overlap and gradient cosine, but churn masks.
|
| 8 |
+
- EMA is stable, but can be blind.
|
| 9 |
+
|
| 10 |
+
v17 tests the fusion:
|
| 11 |
+
|
| 12 |
+
radar_score = alpha * normalized_ema_mass
|
| 13 |
+
+ (1 - alpha) * normalized_sensor_score
|
| 14 |
+
|
| 15 |
+
where sensor_score is either:
|
| 16 |
+
- KNN over a learned chunk-mass correlation graph
|
| 17 |
+
- graph diffusion / boundary interpolation over that graph
|
| 18 |
+
|
| 19 |
+
This is still a diagnostic script. It computes dense gradients so we can measure:
|
| 20 |
+
- oracle Jaccard
|
| 21 |
+
- active-only full-gradient cosine
|
| 22 |
+
- mask stability
|
| 23 |
+
- validation loss after sparse active-only updates
|
| 24 |
+
|
| 25 |
+
No inactive gradients are invented. In sparse modes, inactive chunks get zeroed.
|
| 26 |
+
|
| 27 |
+
Run:
|
| 28 |
+
python3 sparse_transformer_v17_radar_scheduler.py --device mps --benchmark_sync
|
| 29 |
+
|
| 30 |
+
Useful:
|
| 31 |
+
python3 sparse_transformer_v17_radar_scheduler.py --device mps --steps 500 --n_embd 512
|
| 32 |
+
python3 sparse_transformer_v17_radar_scheduler.py --device mps --steps 500 --n_embd 1024
|
| 33 |
+
python3 sparse_transformer_v17_radar_scheduler.py --device mps --steps 500 --n_embd 1024 --alphas 0.25 0.5 0.75 0.9
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
|
| 38 |
+
import argparse
|
| 39 |
+
import math
|
| 40 |
+
import random
|
| 41 |
+
import time
|
| 42 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 43 |
+
|
| 44 |
+
import torch
|
| 45 |
+
|
| 46 |
+
torch.set_num_threads(1)
|
| 47 |
+
import torch.nn as nn
|
| 48 |
+
import torch.nn.functional as F
|
| 49 |
+
|
| 50 |
+
# Names of the mask-selection strategies compared in this diagnostic.
# "radar_*" variants blend the EMA mass estimate with a sensor score.
Scheduler = Literal[
    "dense",
    "ema_topk",
    "knn_scheduler",
    "graph_scheduler",
    "radar_knn",
    "radar_graph",
    "random",
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def sync_device(device: str) -> None:
    """Block until all queued kernels on `device` have finished (no-op for CPU).

    Required for honest wall-clock timing because CUDA/MPS kernel launches are
    asynchronous.
    """
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()
    elif device == "mps" and getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        # `hasattr(torch, "mps")` alone is not a safe guard: the module exists
        # on builds without Metal support, where synchronize() can raise.
        torch.mps.synchronize()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def set_seed(seed: int) -> None:
    """Seed both Python's and torch's global RNGs for reproducibility."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU-backed torch RNG deterministically seeded with `seed`."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def normalize_scores(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Min-max rescale `x` into [0, 1]; constant inputs map to all zeros.

    Min-max (rather than z-scoring) is deliberate: heavy tails carry the
    signal here, not noise. NaN/inf entries are treated as 0 before rescaling.
    """
    cleaned = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    low, high = cleaned.min(), cleaned.max()
    span = high - low
    if span <= eps:
        return torch.zeros_like(cleaned)
    return (cleaned - low) / (span + eps)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# -----------------------------
|
| 93 |
+
# Data
|
| 94 |
+
# -----------------------------
|
| 95 |
+
|
| 96 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Generate a deterministic corpus of short random-word sentences, one per line."""
    rng = random.Random(seed)
    vocabulary = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    lines = []
    for _ in range(n_sentences):
        # Draw the length first, then the words, matching the RNG call order.
        n_words = rng.randint(4, 10)
        lines.append(" ".join(rng.choices(vocabulary, k=n_words)) + ".")
    return "\n".join(lines)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class CharCorpus:
    """Character-level corpus with a fixed 90/10 train/val split."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device
        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample `batch_size` (input, next-char target) windows from the split."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        inputs = torch.stack([source[s : s + self.block_size] for s in starts])
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# -----------------------------
|
| 135 |
+
# Model
|
| 136 |
+
# -----------------------------
|
| 137 |
+
|
| 138 |
+
class SparseLinear(nn.Linear):
    """Marker subclass of nn.Linear.

    Behavior is unchanged; the scheduler uses isinstance checks on this type
    to decide which layers get chunked sparse gradient treatment.
    """
    pass
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with SparseLinear projections.

    Uses an explicit materialized attention matrix (no fused/flash kernel) so
    gradients flow through SparseLinear layers the scheduler can chunk.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused QKV projection: output is [q | k | v] along the feature dim.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask, shaped for broadcast over (B, heads, T, T).
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)

        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores with future positions masked to -inf.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        # Merge heads back into the embedding dimension.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class FeedForward(nn.Module):
    """Standard GPT MLP: expand 4x, GELU, project back, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.dropout(projected)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class MiniGPT(nn.Module):
    """Small GPT-style character LM used as the diagnostic testbed.

    Only the Block internals use SparseLinear; embeddings and the LM head are
    plain modules, so the scheduler never masks them.
    """

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        # Plain nn.Linear (not SparseLinear): excluded from chunked scheduling.
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are given."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every SparseLinear in `model`, in module-traversal order."""
    found: List[SparseLinear] = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# -----------------------------
|
| 239 |
+
# Radar scheduler
|
| 240 |
+
# -----------------------------
|
| 241 |
+
|
| 242 |
+
class RadarScheduler:
    """Chooses which output chunks of SparseLinear layers are "active" each step.

    Maintains an EMA of per-chunk gradient mass plus (for the sensor/radar
    variants) a within-layer chunk-correlation graph learned during dense
    warmup. Chunks are numbered globally across all SparseLinear modules in
    traversal order; chunk c of a module covers output rows
    [c * chunk_size, (c + 1) * chunk_size).
    """

    def __init__(
        self,
        model: nn.Module,
        chunk_size: int,
        active_fraction: float,
        device: str,
        scheduler: Scheduler,
        alpha: float,
        mass_beta: float = 0.95,
        similarity_history: int = 128,
        min_similarity_history: int = 8,
    ):
        self.model = model
        self.chunk_size = chunk_size
        self.active_fraction = active_fraction
        self.device = device
        self.scheduler = scheduler
        # Radar blend weight: alpha on EMA mass, (1 - alpha) on sensor score.
        self.alpha = float(alpha)
        # EMA decay for predicted per-chunk gradient mass.
        self.mass_beta = mass_beta
        self.similarity_history = similarity_history
        self.min_similarity_history = min_similarity_history

        self.linears = get_sparse_linears(model)
        # module -> global chunk-id tensor; global chunk-id -> (module, local index).
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []

        offset = 0
        for m in self.linears:
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            for local_c in range(n_chunks):
                self.chunk_to_module_local.append((m, local_c))
            offset += n_chunks

        self.n_chunks = offset
        # EMA estimate of each chunk's gradient mass (0 == never observed).
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # Dense-warmup mass snapshots used to fit the similarity graph.
        self.mass_history: List[torch.Tensor] = []

        self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=device)
        self.next_sensor_scores = torch.zeros(self.n_chunks, device=device)
        self.next_scores = torch.zeros(self.n_chunks, device=device)
        self.prev_mask: Optional[torch.Tensor] = None
        # [n_chunks, n_chunks] non-negative correlation graph, or None pre-warmup.
        self.similarity: Optional[torch.Tensor] = None

    def k_active(self) -> int:
        """Number of chunks kept active each sparse step (at least 1)."""
        return max(1, int(self.active_fraction * self.n_chunks))

    def choose_mask(self, step: int, warmup_steps: int) -> torch.Tensor:
        """Select this step's active mask; dense mode and warmup use all chunks."""
        if self.scheduler == "dense" or step < warmup_steps:
            self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)
            return self.current_mask

        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)

        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        else:
            scores = self.score_for_selection()
            # Tiny noise breaks ties so topk doesn't freeze on equal scores.
            scores = scores + 1e-9 * torch.rand_like(scores)
            idx = torch.topk(scores, k=k).indices

        mask[idx] = True
        self.current_mask = mask
        return mask

    def score_for_selection(self) -> torch.Tensor:
        """Per-chunk selection score according to the configured scheduler."""
        if self.scheduler == "random":
            return torch.zeros_like(self.predicted_mass)

        if self.scheduler == "ema_topk":
            return self.predicted_mass

        if self.scheduler in ("knn_scheduler", "graph_scheduler"):
            # Fall back to the EMA until sensor scores exist.
            if torch.count_nonzero(self.next_sensor_scores).item() == 0:
                return self.predicted_mass
            return self.next_sensor_scores

        if self.scheduler in ("radar_knn", "radar_graph"):
            sensor = self.next_sensor_scores
            if torch.count_nonzero(sensor).item() == 0:
                sensor = self.predicted_mass

            # Blend in normalized space so neither term dominates by scale.
            ema_n = normalize_scores(self.predicted_mass)
            sensor_n = normalize_scores(sensor)
            return self.alpha * ema_n + (1.0 - self.alpha) * sensor_n

        if self.scheduler == "dense":
            return torch.ones_like(self.predicted_mass)

        raise ValueError(f"Unknown scheduler: {self.scheduler}")

    @torch.no_grad()
    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Flattened (weight rows + matching bias slice) gradient per chunk.

        Missing gradients are substituted with zeros so every chunk yields a
        vector of the same layout.
        """
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = (local_c + 1) * self.chunk_size

            parts = []
            if m.weight.grad is None:
                parts.append(torch.zeros_like(m.weight[start:end]).flatten())
            else:
                parts.append(m.weight.grad[start:end].detach().flatten())

            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())

            vecs.append(torch.cat(parts))
        return vecs

    @torch.no_grad()
    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """L2 norm of each chunk's gradient vector, on the scheduler device."""
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    @torch.no_grad()
    def update_from_observed(
        self,
        observed_mask: torch.Tensor,
        true_masses: torch.Tensor,
        step: int,
        warmup_steps: int,
    ) -> None:
        """Fold the observed chunk masses into the EMA and refresh sensor scores.

        Only chunks flagged in `observed_mask` update the EMA; during dense
        warmup the caller passes an all-ones mask.
        """
        observed = observed_mask

        # First observation seeds the EMA directly instead of decaying from 0.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen

        self.predicted_mass[never_seen] = true_masses[never_seen]
        self.predicted_mass[already_seen] = (
            self.mass_beta * self.predicted_mass[already_seen]
            + (1.0 - self.mass_beta) * true_masses[already_seen]
        )

        # Dense warmup teaches the similarity graph.
        if step < warmup_steps:
            self.mass_history.append(true_masses.detach().clone())
            if len(self.mass_history) > self.similarity_history:
                self.mass_history = self.mass_history[-self.similarity_history :]

            if len(self.mass_history) >= self.min_similarity_history:
                self.similarity = self.build_similarity()

        if self.scheduler in ("knn_scheduler", "radar_knn"):
            self.next_sensor_scores = self.knn_scores(observed, true_masses)
        elif self.scheduler in ("graph_scheduler", "radar_graph"):
            self.next_sensor_scores = self.diffusion_scores(observed, true_masses)
        else:
            self.next_sensor_scores = self.predicted_mass.clone()

        # Cache the blended score for diagnostics / the next choose_mask call.
        self.next_scores = self.score_for_selection()

    def build_similarity(self) -> torch.Tensor:
        """Non-negative chunk-mass correlation matrix from the warmup history."""
        H = torch.stack(self.mass_history, dim=0)  # [history, chunks]
        # Standardize each chunk's mass series before correlating.
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)

        S = (H.T @ H) / max(1, H.shape[0] - 1)
        # Keep only positive correlations; self-similarity is removed.
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)

        # Within-layer only. This keeps the graph interpretable and avoids
        # overfitting tiny cross-layer coincidences.
        allowed = torch.zeros_like(S, dtype=torch.bool)
        for _, ids in self.module_to_chunk_ids.items():
            allowed[ids[:, None], ids[None, :]] = True

        S = torch.where(allowed, S, torch.zeros_like(S))
        return S

    def knn_scores(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        k_neighbors: int = 3,
    ) -> torch.Tensor:
        """Estimate inactive-chunk masses from their top-k most similar active chunks.

        Active chunks keep their true mass; an inactive chunk with no positive
        similarity to any active chunk keeps its EMA estimate.
        """
        if self.similarity is None:
            return self.predicted_mass.clone()

        S = self.similarity
        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]

        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()

        if active_idx.numel() == 0:
            return scores

        for i in inactive_idx.tolist():
            weights = S[i, active_idx]
            if weights.sum() <= 1e-12:
                continue
            kk = min(k_neighbors, weights.numel())
            top = torch.topk(weights, k=kk)
            w = top.values
            aidx = active_idx[top.indices]
            # Similarity-weighted average of the observed neighbor masses.
            scores[i] = (w * true_masses[aidx]).sum() / (w.sum() + 1e-12)

        return scores

    def diffusion_scores(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        diffusion_steps: int = 8,
        alpha: float = 0.7,
    ) -> torch.Tensor:
        """Propagate observed masses over the similarity graph (label propagation).

        Active chunks are clamped back to their true masses after every
        diffusion step, so the observations act as boundary conditions.
        `alpha` here is the local diffusion mixing rate, unrelated to the
        radar blend weight self.alpha.
        """
        if self.similarity is None:
            return self.predicted_mass.clone()

        S = self.similarity
        # Row-normalize so each step is a weighted neighbor average.
        W = S / (S.sum(dim=1, keepdim=True) + 1e-12)

        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]

        for _ in range(diffusion_steps):
            proposal = W @ scores
            scores = alpha * proposal + (1.0 - alpha) * scores
            scores[active_mask] = true_masses[active_mask]

        return torch.clamp(scores, min=0.0)

    def oracle_topk_mask(self, true_masses: torch.Tensor) -> torch.Tensor:
        """Hindsight-optimal mask: top-k chunks by true gradient mass."""
        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        mask[torch.topk(true_masses, k=k).indices] = True
        return mask
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# -----------------------------
|
| 482 |
+
# Gradient installation and metrics
|
| 483 |
+
# -----------------------------
|
| 484 |
+
|
| 485 |
+
@torch.no_grad()
def install_active_only_grads(sched: RadarScheduler, active_mask: torch.Tensor) -> None:
    """Zero the gradients of every inactive chunk so opt.step() only moves active ones.

    `active_mask` is a global boolean mask over all chunks; chunk c of a module
    covers output rows [c * chunk_size, (c + 1) * chunk_size). Dense mode is a
    no-op (the full gradient is kept).
    """
    if sched.scheduler == "dense":
        return

    for m, ids in sched.module_to_chunk_ids.items():
        local_active = active_mask[ids]
        if bool(local_active.all()):
            continue  # nothing to zero for this module
        # Expand the per-chunk mask to per-output-row and zero all inactive rows
        # in one shot, instead of the original per-chunk Python loop that issued
        # a separate zero_() (and a .tolist() sync) for every chunk.
        inactive_rows = (~local_active).repeat_interleave(sched.chunk_size)
        if m.weight.grad is not None:
            m.weight.grad[inactive_rows] = 0.0
        if m.bias is not None and m.bias.grad is not None:
            m.bias.grad[inactive_rows] = 0.0
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def dense_cosine_active_only(vecs: List[torch.Tensor], active_mask: torch.Tensor) -> float:
    """Cosine between the dense gradient and the same gradient with inactive chunks zeroed.

    1.0 means the active chunks alone capture the full gradient direction.
    """
    true = torch.cat([v.flatten() for v in vecs])
    approx_parts = []
    for i, v in enumerate(vecs):
        approx_parts.append(v.flatten() if bool(active_mask[i]) else torch.zeros_like(v).flatten())
    approx = torch.cat(approx_parts)
    return float(F.cosine_similarity(true, approx, dim=0).item())
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
    """Jaccard overlap |a AND b| / |a OR b| between boolean masks (empty union -> 0.0)."""
    inter = (a & b).sum().float()
    union = (a | b).sum().float()
    # Clamping the denominator to 1 maps the empty-union case to 0 instead of NaN.
    return float((inter / torch.clamp(union, min=1.0)).item())
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class SimpleAdam:
    """Bare-bones Adam (beta1=0.9, beta2=0.999, eps=1e-8) without bias correction.

    Moment buffers are created lazily per parameter, keyed by identity.
    """

    def __init__(self, model: nn.Module, lr: float = 3e-4):
        self.model = model
        self.lr = lr
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

    def zero_grad(self):
        """Detach gradients by dropping the tensors outright."""
        for param in self.model.parameters():
            param.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one Adam-style update to every parameter that has a gradient."""
        for param in self.model.parameters():
            grad = param.grad
            if grad is None:
                continue
            buffers = self.state.get(param)
            if buffers is None:
                buffers = {"m": torch.zeros_like(param), "v": torch.zeros_like(param)}
                self.state[param] = buffers
            first, second = buffers["m"], buffers["v"]
            first.mul_(0.9).add_(grad, alpha=0.1)
            second.mul_(0.999).addcmul_(grad, grad, value=0.001)
            param.sub_(first / (torch.sqrt(second) + 1e-8), alpha=self.lr)
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Single-batch validation loss under a fixed CPU seed, in eval mode.

    The fixed generator seed makes the validation batch identical across runs,
    so the returned loss is directly comparable between schedulers.
    """
    model.eval()
    with torch.no_grad():
        batch_x, batch_y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
        _, val_loss = model(batch_x, batch_y)
    model.train()
    return float(val_loss.item())
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def run_experiment(
    scheduler_name: Scheduler,
    alpha: float,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train MiniGPT for `steps` with one scheduler config and report metrics.

    Returns a dict with keys:
      val    - single fixed-seed validation loss after training
      ms     - mean wall-clock milliseconds per step
      cos    - mean cosine between full gradient and its active-only restriction
      jacc   - mean Jaccard overlap of chosen mask vs the per-step oracle mask
      stable - mean Jaccard overlap of the mask with the previous step's mask
    Masking metrics are NaN for the dense run (never collected for it).
    """
    # Fixed seed so every scheduler variant starts from the same weights/data.
    set_seed(42)

    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    sched = RadarScheduler(
        model=model,
        chunk_size=chunk_size,
        active_fraction=active_fraction,
        device=device,
        scheduler=scheduler_name,
        alpha=alpha,
    )

    metric_rows = []

    # Optional device sync so the timer measures completed work, not queued work.
    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()

    for step in range(steps):
        # Per-step seeded generator: identical batch sequence across runs.
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))

        active_mask = sched.choose_mask(step=step, warmup_steps=warmup_steps)

        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        # Dense gradients exist at this point; chunk vectors/masses are audit
        # inputs computed before any sparsification.
        vecs = sched.chunk_gradient_vectors()
        masses = sched.chunk_masses_from_vecs(vecs)

        if step >= warmup_steps and scheduler_name != "dense":
            oracle = sched.oracle_topk_mask(masses)
            metric_rows.append(
                {
                    "cos": dense_cosine_active_only(vecs, active_mask),
                    "jacc": jaccard(active_mask, oracle),
                    "stable": jaccard(active_mask, sched.prev_mask) if sched.prev_mask is not None else 0.0,
                }
            )

        # Zero out the gradients of inactive chunks before the optimizer step.
        install_active_only_grads(sched, active_mask)

        # The scheduler only learns from active chunks after warmup.
        # During warmup it observes everything to build the similarity graph.
        observed_for_scheduler = active_mask if step >= warmup_steps else torch.ones_like(active_mask)
        sched.update_from_observed(
            observed_mask=observed_for_scheduler,
            true_masses=masses,
            step=step,
            warmup_steps=warmup_steps,
        )

        sched.prev_mask = active_mask.clone()
        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0

    val_loss = evaluate(model, corpus, batch_size, seed=12345)

    if metric_rows:
        avg_cos = sum(r["cos"] for r in metric_rows) / len(metric_rows)
        avg_jacc = sum(r["jacc"] for r in metric_rows) / len(metric_rows)
        avg_stable = sum(r["stable"] for r in metric_rows) / len(metric_rows)
    else:
        # Dense run (or all-warmup run): no masking metrics were collected.
        avg_cos = float("nan")
        avg_jacc = float("nan")
        avg_stable = float("nan")

    return {
        "val": val_loss,
        "ms": 1000.0 * elapsed / steps,
        "cos": avg_cos,
        "jacc": avg_jacc,
        "stable": avg_stable,
    }
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def build_runs(alphas: List[float]) -> List[Tuple[Scheduler, float, str]]:
    """Enumerate (scheduler, alpha, label) configs: baselines, radar sweeps, random."""
    baselines: List[Tuple[Scheduler, float, str]] = [
        ("dense", 1.0, "dense"),
        ("ema_topk", 1.0, "ema_topk"),
        ("knn_scheduler", 0.0, "knn"),
        ("graph_scheduler", 0.0, "graph"),
    ]
    # Sweep alpha for both radar variants, knn first, then graph.
    radar_sweeps: List[Tuple[Scheduler, float, str]] = [
        (kind, a, f"{kind}_a{a:g}")
        for kind in ("radar_knn", "radar_graph")
        for a in alphas
    ]
    return baselines + radar_sweeps + [("random", 0.0, "random")]
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def main() -> None:
    """Parse CLI flags, run every scheduler configuration, print a result table."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--alphas", type=float, nargs="+", default=[0.25, 0.5, 0.75, 0.9])
    # NOTE(review): default device is "mps" — Apple-silicon oriented; pass
    # --device cpu/cuda elsewhere.
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    runs = build_runs(args.alphas)

    # Header describing the shared configuration of all runs.
    print("\nRadar scheduler diagnostic")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunks={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps}")
    print(f"alphas={args.alphas}\n")
    print(
        f"{'run':>18s} | {'val':>8s} | {'ms/step':>8s} | "
        f"{'grad_cos':>8s} | {'jacc':>8s} | {'stable':>8s}"
    )
    print("-" * 78)

    # One row per configuration; each run trains a fresh model from scratch.
    for scheduler_name, alpha, label in runs:
        result = run_experiment(
            scheduler_name=scheduler_name,
            alpha=alpha,
            device=args.device,
            steps=args.steps,
            batch_size=args.batch_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            chunk_size=args.chunk_size,
            active_fraction=args.active_fraction,
            warmup_steps=args.warmup_steps,
            benchmark_sync=args.benchmark_sync,
        )

        print(
            f"{label:>18s} | "
            f"{result['val']:8.4f} | "
            f"{result['ms']:8.2f} | "
            f"{result['cos']:8.3f} | "
            f"{result['jacc']:8.3f} | "
            f"{result['stable']:8.3f}"
        )
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
# Script entry point: run the full scheduler comparison when executed directly.
if __name__ == "__main__":
    main()
|
experiments/sparse_transformer_v6.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v6: stable predicted-magnitude masks, no dense refresh by default.
|
| 3 |
+
|
| 4 |
+
This prototype is designed to test the next hypothesis after the spiral/MLP runs:
|
| 5 |
+
|
| 6 |
+
The important gradient support is heavy-tailed and temporally stable enough
|
| 7 |
+
that we can select active parameter blocks from history, freeze the rest, and
|
| 8 |
+
still train a harder sequence model.
|
| 9 |
+
|
| 10 |
+
Key fixes versus v5
|
| 11 |
+
-------------------
|
| 12 |
+
1. Harder model: a small causal Transformer language model.
|
| 13 |
+
2. No periodic dense refresh by default: --warmup_steps 0.
|
| 14 |
+
3. The selector only learns from blocks it actually observes/updates.
|
| 15 |
+
4. Inactive Linear rows are truly frozen by MaskedAdam. This matters because
|
| 16 |
+
ordinary Adam can still move parameters with zero gradients through momentum.
|
| 17 |
+
5. A true current-step oracle is included as an audit upper bound.
|
| 18 |
+
6. Random masks are included as a control.
|
| 19 |
+
|
| 20 |
+
Important limitation
|
| 21 |
+
--------------------
|
| 22 |
+
This still calls loss.backward(), so PyTorch computes dense gradients. Those full
|
| 23 |
+
current gradients are used for audit metrics and for the oracle run only. The
|
| 24 |
+
practical predicted_magnitude selector is not allowed to update its statistics
|
| 25 |
+
from inactive full gradients.
|
| 26 |
+
|
| 27 |
+
Actual speedup would require structured partial backward/custom kernels.
|
| 28 |
+
|
| 29 |
+
Run
|
| 30 |
+
---
|
| 31 |
+
python3 sparse_transformer_v6.py --quick
|
| 32 |
+
python3 sparse_transformer_v6.py --steps 1000 --active_fractions 0.10 0.05 0.02
|
| 33 |
+
python3 sparse_transformer_v6.py --text_path input.txt --steps 2000
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
|
| 38 |
+
import argparse
|
| 39 |
+
import math
|
| 40 |
+
import random
|
| 41 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 42 |
+
|
| 43 |
+
import torch
|
| 44 |
+
torch.set_num_threads(1)
|
| 45 |
+
import torch.nn as nn
|
| 46 |
+
import torch.nn.functional as F
|
| 47 |
+
|
| 48 |
+
# Row-selection policies compared by this experiment:
#   predicted_magnitude - pick rows from an EMA of observed gradient mass
#   oracle_current      - top-k by the current step's true gradient (audit upper bound)
#   random              - uniform random control
Policy = Literal["predicted_magnitude", "oracle_current", "random"]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def set_seed(seed: int) -> None:
    """Seed the Torch CPU RNG, all CUDA RNGs (when present), and Python's RNG."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def device() -> str:
    """Prefer "cuda" when available; otherwise fall back to "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# -----------------------------
|
| 63 |
+
# Data
|
| 64 |
+
# -----------------------------
|
| 65 |
+
|
| 66 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic synthetic text corpus of templated sentences.

    Six sentence templates (subject/verb/object lines, clause-prefixed lines,
    rewrite rules, symbol sequences, conditionals) are sampled with a private
    `random.Random(seed)` so the corpus is reproducible for a fixed seed and
    independent of the global RNG state.

    Returns the corpus as one newline-joined string with a trailing newline.
    """
    rng = random.Random(seed)
    # Word pools shared by all templates.
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]

    lines: List[str] = []
    for _ in range(n_sentences):
        # Pick one of six templates uniformly.
        t = rng.randrange(6)
        if t == 0:
            line = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
        elif t == 1:
            line = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
        elif t == 2:
            # Two distinct symbols, each naming a rewrite rule.
            a, b = rng.sample(symbols, 2)
            line = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}."
        elif t == 3:
            line = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
        elif t == 4:
            # Variable-length symbol sequence (2-7 symbols).
            seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
            line = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}."
        else:
            line = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."
        lines.append(line)
    return "\n".join(lines) + "\n"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class CharCorpus:
    """Character-level corpus with a 90/10 train/validation split.

    Builds a sorted-character vocabulary from the text and serves random
    contiguous (input, shifted-by-one target) batches for language modeling.
    """

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        cut = int(0.9 * len(encoded))
        self.train_data = encoded[:cut]
        self.val_data = encoded[cut:]

    def get_batch(self, split: str, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a batch of (x, y) windows; y is x shifted one character right."""
        source = self.train_data if split == "train" else self.val_data
        max_start = len(source) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        starts = torch.randint(max_start, (batch_size,))
        inputs = torch.stack([source[s : s + self.block_size] for s in starts])
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_text(args: argparse.Namespace) -> str:
    """Return the corpus text: the file at --text_path if given, else synthetic."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as f:
        return f.read()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# -----------------------------
|
| 136 |
+
# Mini GPT
|
| 137 |
+
# -----------------------------
|
| 138 |
+
|
| 139 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    A single fused linear produces Q, K, V; attention weights are dropped out
    before the value aggregation, and the output is re-projected to n_embd.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # (1, 1, T, T) lower-triangular buffer; zeros mark future positions.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)
        head_shape = (B, T, self.n_head, self.head_dim)
        q = q.view(*head_shape).transpose(1, 2)
        k = k.view(*head_shape).transpose(1, 2)
        v = v.view(*head_shape).transpose(1, 2)
        # Scaled dot-product scores with future positions set to -inf.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: Linear (4x expansion) -> GELU -> Linear -> Dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class MiniGPT(nn.Module):
    """Minimal GPT: token + learned positional embeddings, N blocks, linear LM head."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are supplied."""
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.drop(self.tok_emb(idx) + self.pos_emb(positions)[None, :, :])
        hidden = self.ln_f(self.blocks(hidden))
        logits = self.lm_head(hidden)
        if targets is None:
            return logits, None
        flat = logits.view(-1, logits.size(-1))
        return logits, F.cross_entropy(flat, targets.view(-1))
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def named_linear_modules(model: nn.Module) -> List[Tuple[str, nn.Linear]]:
    """All (qualified_name, module) pairs for the nn.Linear submodules of model."""
    linears: List[Tuple[str, nn.Linear]] = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            linears.append((name, module))
    return linears
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# -----------------------------
|
| 221 |
+
# Mask selector
|
| 222 |
+
# -----------------------------
|
| 223 |
+
|
| 224 |
+
class RowMasker:
    """Selects which Linear output rows ("blocks") receive optimizer updates.

    Every output row of every nn.Linear in the model is one block with a
    global id. Each step a policy chooses an active subset; MaskedAdam then
    freezes everything else. The predicted_magnitude policy keeps an EMA of
    per-block gradient mass that — by design — is only refreshed from blocks
    it actually observed, so it cannot peek at inactive gradients.
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.device = device

        self.linear_modules = [m for _, m in named_linear_modules(model)]
        # Map each Linear to the global block ids of its output rows.
        self.module_to_ids: Dict[nn.Linear, torch.Tensor] = {}
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            block_ids = torch.arange(offset, offset + n, device=device)
            self.module_to_ids[m] = block_ids
            offset += n
        self.n_blocks = offset

        # Optimistic initialization: unexplored blocks start with mass 1 so
        # they can still win the top-k before any observation.
        self.predicted_mass = torch.ones(self.n_blocks, device=device)
        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[nn.Linear, torch.Tensor] = {m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules}

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask selecting the top `fraction` of values (at least one)."""
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        mask[torch.topk(values, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard overlap of two boolean masks; empty union scores 0."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        """Install a global active mask and rebuild the per-module row masks."""
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]

    def choose_pre_backward(self, step: int) -> None:
        """Pick this step's active set before gradients exist.

        Warmup activates everything; the oracle policy defers selection to
        audit_and_update (it needs current gradients); random and
        predicted_magnitude select their k blocks here.
        """
        if step < self.warmup_steps:
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "oracle_current":
            # Cannot select until after current gradients are known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        k_total = max(1, int(self.active_fraction * self.n_blocks))

        if self.policy == "random":
            active = torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device)
            active[torch.randperm(self.n_blocks, device=self.device)[:k_total]] = True
            self._set_active(active)
            return

        if self.policy != "predicted_magnitude":
            raise ValueError(f"Unknown policy: {self.policy}")

        # Split the budget between exploitation (top predicted mass) and
        # uniform exploration over the remaining blocks.
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device)

        # Tiny random jitter breaks ties between equal predicted masses.
        scores = self.predicted_mass + 1e-8 * torch.rand_like(self.predicted_mass)
        if k_exploit > 0:
            active[torch.topk(scores, k=k_exploit).indices] = True
        if k_explore > 0:
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            active[remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]] = True
        self._set_active(active)

    @torch.no_grad()
    def current_gradient_mass(self) -> torch.Tensor:
        """Per-block L2 gradient mass (weight row plus bias entry) for this step."""
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def audit_and_update(self, step: int) -> Dict[str, float]:
        """After backward: finalize the mask, compute audit metrics, update stats.

        Returns cosine/norm-ratio of the sparsified gradient, top-20% mass
        concentration, Jaccard vs the current-step oracle, step-to-step mask
        stability, and the realized active fraction.
        """
        mass = self.current_gradient_mass()

        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            # The oracle may now select: current gradients are available.
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
        # With zero inactive blocks and active blocks using true gradient, cosine == norm ratio.
        norm_ratio = cosine

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        # Heavy-tail diagnostic: share of total mass held by the top 20% of blocks.
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        # Strict rule: do not update stats from inactive full gradients.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        self.predicted_mass[observed] = (
            self.mass_beta * self.predicted_mass[observed]
            + (1.0 - self.mass_beta) * mass[observed]
        )

        return {
            "cosine": cosine,
            "norm_ratio": norm_ratio,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
        }

    def row_mask_for(self, module: nn.Linear) -> Optional[torch.Tensor]:
        """Boolean per-row active mask for `module`, or None if unknown."""
        return self.row_masks.get(module)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
# -----------------------------
|
| 375 |
+
# Masked Adam
|
| 376 |
+
# -----------------------------
|
| 377 |
+
|
| 378 |
+
class MaskedAdam:
    """Adam that updates only masker-selected Linear rows and freezes the rest.

    For masked parameters, the moment buffers of inactive rows are NOT updated
    either — ordinary Adam would keep moving zero-gradient rows through stale
    momentum, which is exactly what this prevents. Parameters outside the
    masker's Linear modules (embeddings, LayerNorms) get a plain Adam step.
    No bias correction is applied.
    """

    def __init__(self, model: nn.Module, masker: Optional[RowMasker], lr: float, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.0):
        self.model = model
        self.masker = masker  # None => behaves like plain (uncorrected) Adam
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        # Lazily created per-parameter first/second moment buffers.
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Map each Linear weight/bias parameter back to (module, kind) so the
        # step can look up the module's current row mask.
        self.linear_param: Dict[nn.Parameter, Tuple[nn.Linear, str]] = {}
        for _, m in named_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        # Drop gradients entirely (set to None) rather than zeroing in place.
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                # Decoupled-from-masking L2 term folded into the gradient.
                g = g.add(p, alpha=self.weight_decay)

            # Resolve the row mask for this parameter, if it belongs to a
            # masked Linear. Weight masks broadcast over all trailing dims.
            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                # Unmasked parameter: standard in-place Adam update.
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                mask = row_mask.expand_as(p)
                if not bool(mask.any().item()):
                    continue
                # Compute the candidate moments densely, then write back only
                # the active entries: inactive rows keep frozen m, v, and p.
                new_m = self.beta1 * m + (1.0 - self.beta1) * g
                new_v = self.beta2 * v + (1.0 - self.beta2) * g * g
                m[mask] = new_m[mask]
                v[mask] = new_v[mask]
                update = m / (torch.sqrt(v) + self.eps)
                p[mask] = p[mask] - self.lr * update[mask]
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# -----------------------------
|
| 434 |
+
# Training
|
| 435 |
+
# -----------------------------
|
| 436 |
+
|
| 437 |
+
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int) -> Dict[str, float]:
    """Mean loss over eval_iters random batches for both splits; restores train mode."""
    model.eval()
    results: Dict[str, float] = {}
    for split in ("train", "val"):
        total = 0.0
        for _ in range(eval_iters):
            batch_x, batch_y = corpus.get_batch(split, batch_size)
            _, loss = model(batch_x, batch_y)
            total += float(loss.item())
        results[split] = total / eval_iters
    model.train()
    return results
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def train_run(corpus: CharCorpus, args: argparse.Namespace, policy: Optional[Policy], active_fraction: float, seed_offset: int) -> Dict[str, float | str]:
    """Train one MiniGPT from scratch and return a summary row of metrics.

    When ``policy`` is None this is the dense baseline: no RowMasker is built
    and MaskedAdam updates every parameter. Otherwise a RowMasker selects the
    active rows each step and MaskedAdam freezes the rest.

    Audit metrics (cosine, jaccard, stability, ...) are accumulated only for
    steps at or past ``args.warmup_steps``, then averaged into the returned row.
    ``seed_offset`` differentiates runs that share ``args.seed``.
    """
    set_seed(args.seed + seed_offset)
    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=args.explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=args.warmup_steps,
            device=dev,
        )
    opt = MaskedAdam(model, masker, lr=args.lr, weight_decay=args.weight_decay)

    # Running sums of per-step audit metrics; averaged over `count` at the end.
    sums = {"cosine": 0.0, "norm_ratio": 0.0, "top20_mass": 0.0, "jacc_oracle": 0.0, "stability": 0.0, "active_fraction_real": 0.0}
    count = 0

    for step in range(args.steps):
        x, y = corpus.get_batch("train", args.batch_size)
        if masker is not None:
            # Select the active row set before gradients exist (practical policies).
            masker.choose_pre_backward(step)
        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        if masker is not None:
            # Dense gradients are now available: audit selection quality and
            # update the masker's EMA statistics from the observed rows.
            metrics = masker.audit_and_update(step)
            if step >= args.warmup_steps:
                for k in sums:
                    sums[k] += metrics[k]
                count += 1
        opt.step()

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
            name = "dense" if policy is None else policy
            print(f"{name:20s} step={step:5d} train={losses['train']:.4f} val={losses['val']:.4f}")

    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "target_active": 1.0 if policy is None else active_fraction,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        # Dense baseline (or no audited steps): audit metrics are undefined.
        row.update({"cosine": float("nan"), "norm_ratio": float("nan"), "top20_mass": float("nan"), "jacc_oracle": float("nan"), "stability": float("nan"), "active_fraction_real": 1.0})
    else:
        for k, v in sums.items():
            row[k] = v / count
    return row
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def print_summary(rows: List[Dict[str, float | str]]) -> None:
    """Print an aligned results table, one line per run row, to stdout."""
    print("\nSummary")
    header = f"{'run':>22s} {'target':>7s} {'actual':>7s} {'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} {'stable':>7s}"
    print(header)
    print("-" * len(header))
    for entry in rows:
        fields = [
            f"{str(entry['run']):>22s}",
            f"{float(entry['target_active']):7.3f}",
            f"{float(entry['active_fraction_real']):7.3f}",
            f"{float(entry['val_loss']):8.4f}",
            f"{float(entry['train_loss']):8.4f}",
            f"{float(entry['cosine']):7.3f}",
            f"{float(entry['top20_mass']):7.3f}",
            f"{float(entry['jacc_oracle']):7.3f}",
            f"{float(entry['stability']):7.3f}",
        ]
        print(" ".join(fields))
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def parse_args() -> argparse.Namespace:
    """Build the experiment's CLI and parse ``sys.argv``.

    All options keep their historical names and defaults; the table below is
    just a compact way of declaring them.
    """
    parser = argparse.ArgumentParser()
    specs = [
        ("--text_path", dict(type=str, default=None)),
        ("--synthetic_sentences", dict(type=int, default=12000)),
        ("--steps", dict(type=int, default=1000)),
        ("--quick", dict(action="store_true")),
        ("--batch_size", dict(type=int, default=32)),
        ("--block_size", dict(type=int, default=64)),
        ("--n_layer", dict(type=int, default=2)),
        ("--n_head", dict(type=int, default=4)),
        ("--n_embd", dict(type=int, default=64)),
        ("--dropout", dict(type=float, default=0.0)),
        ("--lr", dict(type=float, default=3e-4)),
        ("--weight_decay", dict(type=float, default=0.0)),
        ("--active_fractions", dict(type=float, nargs="+", default=[0.10, 0.05, 0.02])),
        ("--explore_fraction", dict(type=float, default=0.10)),
        ("--mass_beta", dict(type=float, default=0.95)),
        ("--unobserved_decay", dict(type=float, default=0.999)),
        ("--warmup_steps", dict(type=int, default=0)),
        ("--eval_interval", dict(type=int, default=200)),
        ("--eval_iters", dict(type=int, default=20)),
        ("--seed", dict(type=int, default=7)),
        ("--verbose", dict(action="store_true")),
    ]
    for flag, kwargs in specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def main() -> None:
    """Entry point: run a dense baseline, then every policy/active-fraction pair.

    ``--quick`` shrinks every dimension of the experiment (steps, model size,
    corpus size) to make a fast smoke test. Results are collected per run and
    printed as a single summary table at the end.
    """
    args = parse_args()
    if args.quick:
        # Smoke-test overrides: tiny model, short run, reduced sweep.
        args.steps = 60
        args.eval_iters = 3
        args.batch_size = 16
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 2000
        args.active_fractions = [0.10, 0.02]

    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"warmup_steps={args.warmup_steps} explore_fraction={args.explore_fraction}")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(train_run(corpus, args, policy=None, active_fraction=1.0, seed_offset=0))

    # Distinct seed offsets keep sparse runs decorrelated from the baseline
    # and from each other while still deriving from args.seed.
    seed_offset = 100
    for af in args.active_fractions:
        for policy in ["oracle_current", "predicted_magnitude", "random"]:
            print(f"\nRunning policy={policy}, active_fraction={af:.3f}")
            rows.append(train_run(corpus, args, policy=policy, active_fraction=af, seed_offset=seed_offset))
            seed_offset += 1

    print_summary(rows)

    print("\nNotes")
    print(" oracle_current uses the current full gradient to choose rows; it is an upper bound, not a practical selector.")
    print(" predicted_magnitude chooses from EMA mass only, plus a small random exploration budget.")
    print(" EMA mass is updated only for active/observed rows, not all rows.")
    print(" inactive Linear rows are frozen by MaskedAdam, including Adam state; zero grad alone is not enough.")
    print(" dense gradients are still computed for audit, so this is not a wall-clock speed benchmark yet.")
| 594 |
+
|
| 595 |
+
if __name__ == "__main__":
|
| 596 |
+
main()
|
experiments/sparse_transformer_v7.py
ADDED
|
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v7: discovery stress tests for stable gradient-support masks.
|
| 3 |
+
|
| 4 |
+
This version follows the v6 result:
|
| 5 |
+
|
| 6 |
+
oracle_current works far better than random, so useful sparse support exists;
|
| 7 |
+
predicted_magnitude without warmup does not reliably discover that support.
|
| 8 |
+
|
| 9 |
+
v7 focuses on discovery mechanisms:
|
| 10 |
+
|
| 11 |
+
1. predicted_magnitude
|
| 12 |
+
Exploit rows with the largest EMA-observed gradient mass.
|
| 13 |
+
|
| 14 |
+
2. ucb_magnitude
|
| 15 |
+
A bandit-style selector: EMA mass + an uncertainty bonus for under-observed rows.
|
| 16 |
+
This is meant to discover useful rows without dense refresh.
|
| 17 |
+
|
| 18 |
+
First observation initializes EMA scale immediately.
|
| 19 |
+
|
| 20 |
+
3. stale_current
|
| 21 |
+
A renamed diagnostic control: use the previous full-gradient mass. It is not
|
| 22 |
+
practical because it relies on dense audit gradients, but it tells us whether
|
| 23 |
+
one-step lag is too noisy.
|
| 24 |
+
|
| 25 |
+
4. oracle_current
|
| 26 |
+
True current top-k by dense gradient mass. Upper bound only.
|
| 27 |
+
|
| 28 |
+
5. random
|
| 29 |
+
Control.
|
| 30 |
+
|
| 31 |
+
Important limitation
|
| 32 |
+
--------------------
|
| 33 |
+
This still calls loss.backward(), so PyTorch computes dense gradients. Dense
|
| 34 |
+
current gradients are used for audit metrics and for oracle/stale controls.
|
| 35 |
+
The practical selectors only update their EMA statistics from active rows.
|
| 36 |
+
Actual speedup would require structured partial backward/custom kernels.
|
| 37 |
+
|
| 38 |
+
Example runs
|
| 39 |
+
------------
|
| 40 |
+
Smoke test:
|
| 41 |
+
python3 sparse_transformer_v7.py --quick
|
| 42 |
+
|
| 43 |
+
No-warmup discovery test:
|
| 44 |
+
python3 sparse_transformer_v7.py --steps 1000 \
|
| 45 |
+
--active_fractions 0.10 0.05 0.02 \
|
| 46 |
+
--policies predicted_magnitude ucb_magnitude oracle_current random \
|
| 47 |
+
--warmup_steps_list 0 5 50 --explore_fractions 0.10 0.30
|
| 48 |
+
|
| 49 |
+
Warm-start separation test:
|
| 50 |
+
python3 sparse_transformer_v7.py --steps 1000 \
|
| 51 |
+
--active_fractions 0.10 0.05 0.02 \
|
| 52 |
+
--policies predicted_magnitude ucb_magnitude oracle_current random \
|
| 53 |
+
--warmup_steps_list 0 5 50 200 --explore_fractions 0.10
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
from __future__ import annotations
|
| 57 |
+
|
| 58 |
+
import argparse
|
| 59 |
+
import math
|
| 60 |
+
import random
|
| 61 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 62 |
+
|
| 63 |
+
import torch
|
| 64 |
+
|
| 65 |
+
torch.set_num_threads(1)
|
| 66 |
+
import torch.nn as nn
|
| 67 |
+
import torch.nn.functional as F
|
| 68 |
+
|
| 69 |
+
Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def set_seed(seed: int) -> None:
    """Make a run reproducible: seed torch (CPU and, if present, CUDA) and stdlib random."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def device() -> str:
    """Name of the compute device to use: "cuda" when available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# -----------------------------
|
| 84 |
+
# Data
|
| 85 |
+
# -----------------------------
|
| 86 |
+
|
| 87 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Generate a deterministic synthetic character-level corpus.

    Each line is drawn from one of six sentence templates using a private
    ``random.Random(seed)``, so identical (n_sentences, seed) pairs always
    produce identical text. The returned string ends with a newline.

    NOTE: the order of RNG calls inside each template is part of the output
    contract; changing it would change the generated corpus.
    """
    rand = random.Random(seed)
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]

    sentences: List[str] = []
    for _ in range(n_sentences):
        template = rand.randrange(6)
        if template == 0:
            sentence = f"{rand.choice(names)} {rand.choice(verbs)} {rand.choice(objects)} {rand.choice(adverbs)}."
        elif template == 1:
            sentence = f"{rand.choice(clauses)}, {rand.choice(names)} {rand.choice(verbs)} {rand.choice(objects)}."
        elif template == 2:
            a, b = rand.sample(symbols, 2)
            sentence = f"rule {a}: {rand.choice(objects)} -> {rand.choice(objects)}; rule {b}: {rand.choice(objects)} -> {rand.choice(objects)}."
        elif template == 3:
            sentence = f"the {rand.choice(objects)} {rand.choice(verbs)} the {rand.choice(objects)} {rand.choice(adverbs)}."
        elif template == 4:
            seq = " ".join(rand.choice(symbols) for _ in range(rand.randint(2, 7)))
            sentence = f"sequence {seq} ends when {rand.choice(names)} {rand.choice(verbs)}."
        else:
            sentence = f"if {rand.choice(objects)} rise then {rand.choice(names)} {rand.choice(verbs)} {rand.choice(objects)} else wait."
        sentences.append(sentence)
    return "\n".join(sentences) + "\n"
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class CharCorpus:
    """Character-level dataset with a fixed 90/10 train/val split.

    Builds char<->id vocabularies from the raw text and serves random
    contiguous (input, next-char-target) batches on the requested device.
    """

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        boundary = int(0.9 * len(encoded))
        self.train_data = encoded[:boundary]
        self.val_data = encoded[boundary:]

    def get_batch(self, split: str, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` random windows; targets are inputs shifted by one."""
        source = self.train_data if split == "train" else self.val_data
        max_start = len(source) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        starts = torch.randint(max_start, (batch_size,))
        inputs = torch.stack([source[s : s + self.block_size] for s in starts])
        targets = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return inputs.to(self.device), targets.to(self.device)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def load_text(args: argparse.Namespace) -> str:
    """Return the training text: the file at --text_path if given, else synthetic."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as fh:
        return fh.read()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# -----------------------------
|
| 157 |
+
# Mini GPT
|
| 158 |
+
# -----------------------------
|
| 159 |
+
|
| 160 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with an explicit lower-triangular causal mask."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Single fused projection producing q, k, v side by side.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        tril = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", tril.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)
        heads = (B, T, self.n_head, self.head_dim)
        q = q.view(*heads).transpose(1, 2)
        k = k.view(*heads).transpose(1, 2)
        v = v.view(*heads).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: Linear (4x expansion) -> GELU -> Linear -> Dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each behind a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class MiniGPT(nn.Module):
    """Minimal GPT-style decoder-only character language model.

    Token + learned positional embeddings feed a stack of pre-norm blocks,
    a final LayerNorm, and a linear head over the vocabulary.
    """

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are supplied."""
        B, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.drop(self.tok_emb(idx) + self.pos_emb(positions)[None, :, :])
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def named_linear_modules(model: nn.Module) -> List[Tuple[str, nn.Linear]]:
    """All nn.Linear submodules of *model*, paired with their qualified names."""
    found: List[Tuple[str, nn.Linear]] = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            found.append((name, module))
    return found
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    """Return (total params, params inside nn.Linear layers, linear share of total)."""
    total = sum(p.numel() for p in model.parameters())
    in_linear = sum(
        m.weight.numel() + (m.bias.numel() if m.bias is not None else 0)
        for _, m in named_linear_modules(model)
    )
    # max(1, total) guards the degenerate parameter-free model.
    return total, in_linear, in_linear / max(1, total)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# -----------------------------
|
| 252 |
+
# Mask selector
|
| 253 |
+
# -----------------------------
|
| 254 |
+
|
| 255 |
+
class RowMasker:
|
| 256 |
+
def __init__(
|
| 257 |
+
self,
|
| 258 |
+
model: nn.Module,
|
| 259 |
+
policy: Policy,
|
| 260 |
+
active_fraction: float,
|
| 261 |
+
explore_fraction: float,
|
| 262 |
+
mass_beta: float,
|
| 263 |
+
unobserved_decay: float,
|
| 264 |
+
warmup_steps: int,
|
| 265 |
+
ucb_alpha: float,
|
| 266 |
+
mass_init: float,
|
| 267 |
+
device: str,
|
| 268 |
+
):
|
| 269 |
+
self.model = model
|
| 270 |
+
self.policy = policy
|
| 271 |
+
self.active_fraction = active_fraction
|
| 272 |
+
self.explore_fraction = explore_fraction
|
| 273 |
+
self.mass_beta = mass_beta
|
| 274 |
+
self.unobserved_decay = unobserved_decay
|
| 275 |
+
self.warmup_steps = warmup_steps
|
| 276 |
+
self.ucb_alpha = ucb_alpha
|
| 277 |
+
self.mass_init = mass_init
|
| 278 |
+
self.device = device
|
| 279 |
+
self.step_index = 0
|
| 280 |
+
|
| 281 |
+
self.linear_modules = [m for _, m in named_linear_modules(model)]
|
| 282 |
+
self.module_to_ids: Dict[nn.Linear, torch.Tensor] = {}
|
| 283 |
+
ids = []
|
| 284 |
+
offset = 0
|
| 285 |
+
for m in self.linear_modules:
|
| 286 |
+
n = m.weight.shape[0]
|
| 287 |
+
block_ids = torch.arange(offset, offset + n, device=device)
|
| 288 |
+
self.module_to_ids[m] = block_ids
|
| 289 |
+
ids.append(block_ids)
|
| 290 |
+
offset += n
|
| 291 |
+
self.n_blocks = offset
|
| 292 |
+
|
| 293 |
+
self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
|
| 294 |
+
self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
|
| 295 |
+
self.observed_count = torch.zeros(self.n_blocks, device=device)
|
| 296 |
+
self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)
|
| 297 |
+
|
| 298 |
+
self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
|
| 299 |
+
self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
|
| 300 |
+
self.row_masks: Dict[nn.Linear, torch.Tensor] = {
|
| 301 |
+
m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
|
| 305 |
+
k = max(1, int(fraction * values.numel()))
|
| 306 |
+
mask = torch.zeros_like(values, dtype=torch.bool)
|
| 307 |
+
# Tie-breaking noise matters when many rows have identical initial scores.
|
| 308 |
+
noisy = values + 1e-9 * torch.rand_like(values)
|
| 309 |
+
mask[torch.topk(noisy, k=k).indices] = True
|
| 310 |
+
return mask
|
| 311 |
+
|
| 312 |
+
@staticmethod
|
| 313 |
+
def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
|
| 314 |
+
inter = (a & b).sum().float()
|
| 315 |
+
union = (a | b).sum().float()
|
| 316 |
+
return float((inter / torch.clamp(union, min=1.0)).item())
|
| 317 |
+
|
| 318 |
+
def _set_active(self, active: torch.Tensor) -> None:
|
| 319 |
+
self.active = active
|
| 320 |
+
self.row_masks = {}
|
| 321 |
+
for m, ids in self.module_to_ids.items():
|
| 322 |
+
self.row_masks[m] = active[ids]
|
| 323 |
+
|
| 324 |
+
def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
|
| 325 |
+
n = self.n_blocks
|
| 326 |
+
k_total = max(1, int(self.active_fraction * n))
|
| 327 |
+
k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
|
| 328 |
+
k_exploit = k_total - k_explore
|
| 329 |
+
active = torch.zeros(n, dtype=torch.bool, device=self.device)
|
| 330 |
+
|
| 331 |
+
if k_exploit > 0:
|
| 332 |
+
active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
|
| 333 |
+
if k_explore > 0:
|
| 334 |
+
remaining = torch.nonzero(~active, as_tuple=False).flatten()
|
| 335 |
+
pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
|
| 336 |
+
active[pick] = True
|
| 337 |
+
return active
|
| 338 |
+
|
| 339 |
+
def choose_pre_backward(self, step: int) -> None:
|
| 340 |
+
self.step_index = step
|
| 341 |
+
if step < self.warmup_steps:
|
| 342 |
+
self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
|
| 343 |
+
return
|
| 344 |
+
|
| 345 |
+
if self.policy == "oracle_current":
|
| 346 |
+
# Cannot select until after current gradients are known.
|
| 347 |
+
self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
|
| 348 |
+
return
|
| 349 |
+
|
| 350 |
+
if self.policy == "random":
|
| 351 |
+
self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
|
| 352 |
+
return
|
| 353 |
+
|
| 354 |
+
if self.policy == "stale_current":
|
| 355 |
+
self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
|
| 356 |
+
return
|
| 357 |
+
|
| 358 |
+
if self.policy == "predicted_magnitude":
|
| 359 |
+
self._set_active(self._sample_exploit_explore(self.predicted_mass))
|
| 360 |
+
return
|
| 361 |
+
|
| 362 |
+
if self.policy == "ucb_magnitude":
|
| 363 |
+
t = max(1, step - self.warmup_steps + 1)
|
| 364 |
+
log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
|
| 365 |
+
bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
|
| 366 |
+
bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
|
| 367 |
+
scores = self.predicted_mass + bonus
|
| 368 |
+
self._set_active(self._sample_exploit_explore(scores))
|
| 369 |
+
return
|
| 370 |
+
|
| 371 |
+
raise ValueError(f"Unknown policy: {self.policy}")
|
| 372 |
+
|
| 373 |
+
@torch.no_grad()
|
| 374 |
+
def current_gradient_mass(self) -> torch.Tensor:
|
| 375 |
+
mass = torch.zeros(self.n_blocks, device=self.device)
|
| 376 |
+
for m, ids in self.module_to_ids.items():
|
| 377 |
+
if m.weight.grad is None:
|
| 378 |
+
continue
|
| 379 |
+
row_sq = m.weight.grad.square().sum(dim=1)
|
| 380 |
+
if m.bias is not None and m.bias.grad is not None:
|
| 381 |
+
row_sq = row_sq + m.bias.grad.square()
|
| 382 |
+
mass[ids] = torch.sqrt(row_sq + 1e-30)
|
| 383 |
+
return mass
|
| 384 |
+
|
| 385 |
+
@torch.no_grad()
|
| 386 |
+
def audit_and_update(self, step: int) -> Dict[str, float]:
|
| 387 |
+
mass = self.current_gradient_mass()
|
| 388 |
+
|
| 389 |
+
if step < self.warmup_steps:
|
| 390 |
+
active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
|
| 391 |
+
self._set_active(active)
|
| 392 |
+
elif self.policy == "oracle_current":
|
| 393 |
+
active = self._topk_mask(mass, self.active_fraction)
|
| 394 |
+
self._set_active(active)
|
| 395 |
+
else:
|
| 396 |
+
active = self.active
|
| 397 |
+
|
| 398 |
+
true_sq = mass.square().sum()
|
| 399 |
+
approx_sq = mass[active].square().sum()
|
| 400 |
+
cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
|
| 401 |
+
norm_ratio = cosine
|
| 402 |
+
|
| 403 |
+
oracle_mask = self._topk_mask(mass, self.active_fraction)
|
| 404 |
+
jacc = self._jaccard(active, oracle_mask)
|
| 405 |
+
stability = self._jaccard(active, self.prev_active)
|
| 406 |
+
self.prev_active = active.clone()
|
| 407 |
+
|
| 408 |
+
k20 = max(1, int(0.2 * self.n_blocks))
|
| 409 |
+
sorted_mass = torch.sort(mass, descending=True).values
|
| 410 |
+
top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())
|
| 411 |
+
|
| 412 |
+
new_active = active & (self.observed_count == 0)
|
| 413 |
+
|
| 414 |
+
# Strict rule for practical policies: update stats only from active rows.
|
| 415 |
+
# oracle_current and stale_current also update only active rows for consistency;
|
| 416 |
+
# stale_current separately records last_full_mass as a diagnostic signal.
|
| 417 |
+
self.predicted_mass.mul_(self.unobserved_decay)
|
| 418 |
+
observed = active
|
| 419 |
+
if bool(observed.any().item()):
|
| 420 |
+
obs_mass = mass[observed]
|
| 421 |
+
first_seen = self.observed_count[observed] == 0
|
| 422 |
+
ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
|
| 423 |
+
# First observation should establish the real scale immediately.
|
| 424 |
+
# Otherwise a beta=0.95 EMA needs many observations to climb from zero.
|
| 425 |
+
self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
|
| 426 |
+
self.observed_count[observed] += 1.0
|
| 427 |
+
self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()
|
| 428 |
+
|
| 429 |
+
# Dense audit signal. Only stale_current is allowed to use this for next-step selection.
|
| 430 |
+
self.last_full_mass = mass.detach().clone()
|
| 431 |
+
|
| 432 |
+
coverage = float((self.observed_count > 0).float().mean().item())
|
| 433 |
+
avg_obs_count = float(self.observed_count.mean().item())
|
| 434 |
+
new_active_fraction = float((new_active.float().mean()).item())
|
| 435 |
+
|
| 436 |
+
return {
|
| 437 |
+
"cosine": cosine,
|
| 438 |
+
"norm_ratio": norm_ratio,
|
| 439 |
+
"top20_mass": top20_mass,
|
| 440 |
+
"jacc_oracle": jacc,
|
| 441 |
+
"stability": stability,
|
| 442 |
+
"active_fraction_real": float(active.float().mean().item()),
|
| 443 |
+
"coverage": coverage,
|
| 444 |
+
"avg_obs_count": avg_obs_count,
|
| 445 |
+
"new_active_fraction": new_active_fraction,
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
def row_mask_for(self, module: nn.Linear) -> Optional[torch.Tensor]:
    """Look up the active-row mask previously stored for *module* (None if untracked)."""
    mask = self.row_masks.get(module)
    return mask
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# -----------------------------
|
| 453 |
+
# Masked Adam
|
| 454 |
+
# -----------------------------
|
| 455 |
+
|
| 456 |
+
class MaskedAdam:
    """Adam-style optimizer whose Linear-row updates can be gated by a RowMasker.

    Notes
    -----
    - No bias correction is applied: the raw first/second moments m and v are
      used directly in the update (intentional for this experiment).
    - When a row mask is present for a parameter, the moments and the weights
      are updated ONLY at masked entries; unmasked entries keep stale moments.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        # Lazily-populated per-parameter moment buffers ("m" and "v").
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Maps each Linear weight/bias parameter back to (module, "weight"|"bias")
        # so step() can fetch the module's row mask from the masker.
        self.linear_param: Dict[nn.Parameter, Tuple[nn.Linear, str]] = {}
        for _, m in named_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        """Drop all gradients (sets .grad = None, freeing the tensors)."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one (optionally row-masked) Adam update to every parameter with a grad."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                # Optional stricter mode: freeze embeddings/layernorm/etc. in sparse runs.
                continue

            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                # Decoupled-style L2 folded into the gradient (g + wd * p).
                g = g.add(p, alpha=self.weight_decay)

            # Resolve the row mask (if any) for this Linear parameter.
            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    # Weight masks broadcast over trailing dims (rows x in_features);
                    # bias masks apply elementwise.
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                # Dense Adam update (in place, no bias correction).
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                mask = row_mask.expand_as(p)
                if not bool(mask.any().item()):
                    continue
                # Compute dense candidate moments, then commit only masked entries.
                new_m = self.beta1 * m + (1.0 - self.beta1) * g
                new_v = self.beta2 * v + (1.0 - self.beta2) * g * g
                m[mask] = new_m[mask]
                v[mask] = new_v[mask]
                update = m / (torch.sqrt(v) + self.eps)
                p[mask] = p[mask] - self.lr * update[mask]
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# -----------------------------
|
| 526 |
+
# Training
|
| 527 |
+
# -----------------------------
|
| 528 |
+
|
| 529 |
+
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int) -> Dict[str, float]:
    """Average the model's loss over eval_iters fresh batches for each split.

    Toggles the model to eval() for the measurement and back to train() after.
    Returns {"train": mean_loss, "val": mean_loss}.
    """
    model.eval()
    results: Dict[str, float] = {}
    for split in ("train", "val"):
        total = 0.0
        for _ in range(eval_iters):
            xb, yb = corpus.get_batch(split, batch_size)
            _, batch_loss = model(xb, yb)
            total += float(batch_loss.item())
        results[split] = total / eval_iters
    model.train()
    return results
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    """Train one model configuration and return a summary row.

    policy=None trains a dense baseline (no masker); otherwise a RowMasker
    chooses active Linear rows each step and MaskedAdam restricts updates to
    them. Audit metrics are accumulated only for steps >= warmup_steps.

    Returns a dict with run id, target/actual sparsity, final train/val losses,
    and (for sparse runs) the averaged audit metrics; sparse-only metrics are
    NaN for the dense baseline.
    """
    # Per-run seed so different configurations do not share batch sequences.
    set_seed(args.seed + seed_offset)
    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )

    # Running sums of per-step audit metrics (averaged over `count` at the end).
    sums = {
        "cosine": 0.0,
        "norm_ratio": 0.0,
        "top20_mass": 0.0,
        "jacc_oracle": 0.0,
        "stability": 0.0,
        "active_fraction_real": 0.0,
        "coverage": 0.0,
        "avg_obs_count": 0.0,
        "new_active_fraction": 0.0,
    }
    count = 0

    for step in range(args.steps):
        x, y = corpus.get_batch("train", args.batch_size)
        # Row selection must happen BEFORE backward so the masked update uses it.
        if masker is not None:
            masker.choose_pre_backward(step)
        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        # Audit AFTER backward (gradients available) but BEFORE opt.step().
        if masker is not None:
            metrics = masker.audit_and_update(step)
            if step >= warmup_steps:
                for k in sums:
                    sums[k] += metrics[k]
                count += 1
        opt.step()

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
            name = "dense" if policy is None else policy
            print(
                f"{name:20s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"train={losses['train']:.4f} val={losses['val']:.4f}"
            )

    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        # Dense baseline (or all steps inside warmup): audit metrics undefined.
        row.update({
            "cosine": float("nan"),
            "norm_ratio": float("nan"),
            "top20_mass": float("nan"),
            "jacc_oracle": float("nan"),
            "stability": float("nan"),
            "active_fraction_real": 1.0,
            "coverage": float("nan"),
            "avg_obs_count": float("nan"),
            "new_active_fraction": float("nan"),
        })
    else:
        for k, v in sums.items():
            row[k] = v / count
    return row
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def print_summary(rows: List[Dict[str, float | str]]) -> None:
    """Print an aligned table with one line per run-summary row."""
    print("\nSummary")
    header = (
        f"{'run':>22s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
        f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
        f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
    )
    print(header)
    print("-" * len(header))
    for r in rows:
        # Format each column separately, then join — output matches the header widths.
        fields = [
            f"{str(r['run']):>22s}",
            f"{float(r['target_active']):7.3f}",
            f"{float(r['active_fraction_real']):7.3f}",
            f"{int(float(r['warmup'])):5d}",
            f"{float(r['explore']):5.2f}",
            f"{float(r['val_loss']):8.4f}",
            f"{float(r['train_loss']):8.4f}",
            f"{float(r['cosine']):7.3f}",
            f"{float(r['top20_mass']):7.3f}",
            f"{float(r['jacc_oracle']):7.3f}",
            f"{float(r['stability']):7.3f}",
            f"{float(r['coverage']):7.3f}",
            f"{float(r['new_active_fraction']):7.3f}",
        ]
        print(" ".join(fields))
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Optional explicit argument list. Defaults to None, in which case
            argparse reads sys.argv[1:] — identical to the previous behavior.
            Passing a list makes the function testable without patching sys.argv.

    Returns:
        The populated argparse.Namespace.
    """
    p = argparse.ArgumentParser()
    # Data / corpus.
    p.add_argument("--text_path", type=str, default=None)
    p.add_argument("--synthetic_sentences", type=int, default=12000)
    # Training schedule and model size.
    p.add_argument("--steps", type=int, default=1000)
    p.add_argument("--quick", action="store_true")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--block_size", type=int, default=64)
    p.add_argument("--n_layer", type=int, default=2)
    p.add_argument("--n_head", type=int, default=4)
    p.add_argument("--n_embd", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.0)
    # Sparse-row experiment grid.
    p.add_argument("--active_fractions", type=float, nargs="+", default=[0.10, 0.05, 0.02])
    p.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "ucb_magnitude", "random"])
    p.add_argument("--explore_fractions", type=float, nargs="+", default=[0.10])
    p.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5])
    # RowMasker statistics knobs.
    p.add_argument("--mass_beta", type=float, default=0.95)
    p.add_argument("--unobserved_decay", type=float, default=1.0)
    p.add_argument("--mass_init", type=float, default=0.0)
    p.add_argument("--ucb_alpha", type=float, default=1.0)
    p.add_argument("--freeze_non_linear_when_sparse", action="store_true")
    # Evaluation / logging.
    p.add_argument("--eval_interval", type=int, default=200)
    p.add_argument("--eval_iters", type=int, default=20)
    p.add_argument("--seed", type=int, default=7)
    p.add_argument("--verbose", action="store_true")
    return p.parse_args(argv)
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
def main() -> None:
    """Run the dense baseline plus the full sparse-policy sweep and print a summary."""
    args = parse_args()
    if args.quick:
        # Shrink everything for a fast smoke test.
        args.steps = 60
        args.eval_iters = 3
        args.batch_size = 16
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 2000
        args.active_fractions = [0.10, 0.02]
        args.policies = ["oracle_current", "predicted_magnitude", "ucb_magnitude", "random"]
        args.explore_fractions = [0.10]
        args.warmup_steps_list = [0]

    # Validate policy strings early.
    valid = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
    for pol in args.policies:
        if pol not in valid:
            raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid)}")

    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")

    # Report how much of the model is governed by row masks.
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(train_run(corpus, args, policy=None, active_fraction=1.0, warmup_steps=0, explore_fraction=0.0, seed_offset=0))

    # Each sparse configuration gets its own seed offset for independent batches.
    seed_offset = 100
    for af in args.active_fractions:
        for pol in args.policies:
            # oracle_current and stale_current do not use explore_fraction; random does not either.
            explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
            # Warmup matters for every sparse policy, so keep it in the loop.
            for warmup in args.warmup_steps_list:
                for explore in explore_values:
                    print(f"\nRunning policy={pol}, active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}")
                    rows.append(
                        train_run(
                            corpus,
                            args,
                            policy=pol,  # type: ignore[arg-type]
                            active_fraction=af,
                            warmup_steps=warmup,
                            explore_fraction=explore,
                            seed_offset=seed_offset,
                        )
                    )
                    seed_offset += 1

    print_summary(rows)

    print("\nNotes")
    print(" oracle_current uses current dense gradients to choose rows; it is the true upper bound.")
    print(" stale_current uses previous-step dense gradient mass; it is a renamed stale/noisy control.")
    print(" predicted_magnitude uses only EMA mass from active/observed rows.")
    print(" ucb_magnitude adds an uncertainty bonus for under-observed rows to improve discovery.")
    print(" coverage is the fraction of Linear rows that have ever been observed/active.")
    print(" new is the average fraction of rows newly observed per non-warmup step.")
    print(" dense gradients are still computed for audit; this is not a wall-clock benchmark yet.")


if __name__ == "__main__":
    main()
|
experiments/sparse_transformer_v8.py
ADDED
|
@@ -0,0 +1,943 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v8: from masked-optimizer simulation to real sparse Linear backward.
|
| 3 |
+
|
| 4 |
+
v7 showed that Transformer Linear-row gradient support is heavy-tailed and stable,
|
| 5 |
+
and that a practical EMA selector can nearly match an oracle selector after a tiny
|
| 6 |
+
warmup. But v7 still computed dense gradients and only masked the optimizer step.
|
| 7 |
+
|
| 8 |
+
v8 tests the next question:
|
| 9 |
+
|
| 10 |
+
Can the sparse row mask be moved into the Linear backward pass itself?
|
| 11 |
+
|
| 12 |
+
Backward modes
|
| 13 |
+
--------------
|
| 14 |
+
1. masked_optimizer
|
| 15 |
+
v7-style control. Compute dense backward, but MaskedAdam only updates active
|
| 16 |
+
Linear rows. This should match the previous simulation behavior.
|
| 17 |
+
|
| 18 |
+
2. sparse_dW_full_dX
|
| 19 |
+
Custom autograd Linear computes grad_weight / grad_bias only for active output
|
| 20 |
+
rows, while still propagating full grad_input backward. This is the conservative
|
| 21 |
+
real-backward mode. It targets the dW part of Linear backward only.
|
| 22 |
+
|
| 23 |
+
3. sparse_dW_sparse_dX
|
| 24 |
+
Custom autograd Linear computes grad_weight only for active rows and also
|
| 25 |
+
propagates grad_input only through active output rows. This is the aggressive
|
| 26 |
+
mode. It may save more backward compute in a real kernel, but it can damage
|
| 27 |
+
upstream learning.
|
| 28 |
+
|
| 29 |
+
Important caveat
|
| 30 |
+
----------------
|
| 31 |
+
This script still performs a dense audit backward pass each training step to:
|
| 32 |
+
- compute oracle metrics,
|
| 33 |
+
- support oracle_current and stale_current controls,
|
| 34 |
+
- update practical EMA statistics only for active/observed rows.
|
| 35 |
+
|
| 36 |
+
The actual training update in sparse_dW_* modes comes from the custom sparse
|
| 37 |
+
backward pass, not from the dense audit gradients. This is a correctness and
|
| 38 |
+
semantics experiment, not a wall-clock benchmark.
|
| 39 |
+
|
| 40 |
+
Example
|
| 41 |
+
-------
|
| 42 |
+
Smoke test:
|
| 43 |
+
python3 sparse_transformer_v8.py --quick
|
| 44 |
+
|
| 45 |
+
Main comparison:
|
| 46 |
+
python3 sparse_transformer_v8.py \
|
| 47 |
+
--steps 2000 \
|
| 48 |
+
--active_fractions 0.05 0.02 \
|
| 49 |
+
--warmup_steps_list 5 \
|
| 50 |
+
--explore_fractions 0.00 \
|
| 51 |
+
--policies oracle_current predicted_magnitude random \
|
| 52 |
+
--backward_modes masked_optimizer sparse_dW_full_dX sparse_dW_sparse_dX
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
from __future__ import annotations
|
| 56 |
+
|
| 57 |
+
import argparse
|
| 58 |
+
import math
|
| 59 |
+
import random
|
| 60 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 61 |
+
|
| 62 |
+
import torch
|
| 63 |
+
|
| 64 |
+
torch.set_num_threads(1)
|
| 65 |
+
import torch.nn as nn
|
| 66 |
+
import torch.nn.functional as F
|
| 67 |
+
|
| 68 |
+
Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"]
|
| 69 |
+
BackwardMode = Literal["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# -----------------------------
|
| 73 |
+
# Reproducibility and device
|
| 74 |
+
# -----------------------------
|
| 75 |
+
|
| 76 |
+
def set_seed(seed: int) -> None:
    """Seed Python's RNG and torch (including every CUDA device, when present)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def device() -> str:
    """Return "cuda" when a GPU is available, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Build a deterministic CPU-side torch.Generator seeded with ``seed``."""
    # Generator.manual_seed returns the generator itself, so chaining is safe.
    return torch.Generator(device="cpu").manual_seed(seed)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# -----------------------------
|
| 94 |
+
# Data
|
| 95 |
+
# -----------------------------
|
| 96 |
+
|
| 97 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic synthetic character corpus, one sentence per line.

    The same (n_sentences, seed) pair always yields the same text: all
    randomness flows through one seeded random.Random instance.
    """
    rng = random.Random(seed)
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]

    def _sentence() -> str:
        # One of six templates; RNG call order matches the template exactly so
        # the output is reproducible for a given seed.
        kind = rng.randrange(6)
        if kind == 0:
            return f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
        if kind == 1:
            return f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
        if kind == 2:
            a, b = rng.sample(symbols, 2)
            return f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}."
        if kind == 3:
            return f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
        if kind == 4:
            seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
            return f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}."
        return f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."

    return "\n".join(_sentence() for _ in range(n_sentences)) + "\n"
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class CharCorpus:
    """Character-level corpus with a 90/10 train/val split and random window batches."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        # Bidirectional char <-> id maps.
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        cut = int(0.9 * len(encoded))
        self.train_data = encoded[:cut]
        self.val_data = encoded[cut:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample batch_size (x, y) windows; y is x shifted right by one char.

        Raises ValueError when the chosen split is too short for block_size.
        """
        source = self.val_data if split != "train" else self.train_data
        max_start = len(source) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        starts = torch.randint(max_start, (batch_size,), generator=generator)
        xs = torch.stack([source[s : s + self.block_size] for s in starts])
        ys = torch.stack([source[s + 1 : s + self.block_size + 1] for s in starts])
        return xs.to(self.device), ys.to(self.device)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def load_text(args: argparse.Namespace) -> str:
    """Return corpus text: the file at --text_path when given, otherwise synthetic."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as handle:
        return handle.read()
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# -----------------------------
|
| 172 |
+
# Sparse Linear autograd
|
| 173 |
+
# -----------------------------
|
| 174 |
+
|
| 175 |
+
class MaskedLinearFunction(torch.autograd.Function):
    """F.linear forward with a row-sparse backward pass.

    grad_weight / grad_bias are computed only for output rows flagged True in
    ``active_rows``; other rows receive zero gradient. grad_input is either
    full (propagated through the whole weight) or, with ``sparse_dx=True``,
    propagated only through the active rows.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_rows: torch.Tensor,
        sparse_dx: bool,
    ) -> torch.Tensor:
        # active_rows: boolean mask over the out_features dimension
        # (weight.shape[0]); it controls which rows get gradients in backward.
        ctx.save_for_backward(x, weight, active_rows)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):  # type: ignore[override]
        x, weight, active_rows = ctx.saved_tensors
        sparse_dx = bool(ctx.sparse_dx)
        has_bias = bool(ctx.has_bias)

        # Flatten all leading (batch/time) dims so the grads are 2-D matmuls.
        x_shape = x.shape
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])

        active_idx = torch.nonzero(active_rows, as_tuple=False).flatten()

        # Inactive rows keep exactly-zero gradients.
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if has_bias else None

        if active_idx.numel() > 0:
            gy_active = gy_flat[:, active_idx]
            # dW[r] = sum_t grad_y[t, r] * x[t]  — computed only for active rows.
            grad_weight[active_idx] = gy_active.transpose(0, 1) @ x_flat
            if grad_bias is not None:
                grad_bias[active_idx] = gy_active.sum(dim=0)

            if sparse_dx:
                # Aggressive mode: grad_input flows only through active rows.
                grad_x_flat = gy_active @ weight[active_idx]
            else:
                # Conservative mode: full grad_input through the whole weight.
                grad_x_flat = gy_flat @ weight
        else:
            # This can happen when a global top-k mask selects no rows from a
            # particular layer. Conservative full_dX still propagates through that
            # layer; aggressive sparse_dX cuts it off for that layer.
            if sparse_dx:
                grad_x_flat = torch.zeros_like(x_flat)
            else:
                grad_x_flat = gy_flat @ weight

        grad_x = grad_x_flat.reshape(x_shape)
        # None grads for the non-tensor inputs (active_rows, sparse_dx).
        return grad_x, grad_weight, grad_bias, None, None
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class SparseLinear(nn.Linear):
    """Drop-in nn.Linear whose backward can be restricted to a subset of output rows."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        # Sparse backward is opt-in; the default is stock dense nn.Linear behavior.
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_rows: Optional[torch.Tensor] = None

    def set_sparse_backward(self, enabled: bool, active_rows: Optional[torch.Tensor], sparse_dx: bool) -> None:
        """Configure subsequent forwards: which rows get grads, and whether grad_input is sparsified too."""
        self.sparse_enabled = bool(enabled)
        self.sparse_dx = bool(sparse_dx)
        self.active_rows = active_rows

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        use_sparse = self.sparse_enabled and self.active_rows is not None
        if use_sparse:
            return MaskedLinearFunction.apply(x, self.weight, self.bias, self.active_rows, self.sparse_dx)
        return F.linear(x, self.weight, self.bias)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# -----------------------------
|
| 249 |
+
# Mini GPT
|
| 250 |
+
# -----------------------------
|
| 251 |
+
|
| 252 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on SparseLinear projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask, broadcastable over (batch, head).
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        # (B, T, 3C) -> three tensors of shape (B, n_head, T, head_dim).
        q, k, v = (
            part.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
            for part in qkv.split(C, dim=2)
        )
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: expand 4x, GELU, project back, then dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class MiniGPT(nn.Module):
    """Small GPT-style character language model whose Linear layers are SparseLinear."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = SparseLinear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return ``(logits, loss)``; ``loss`` is None when no targets are given."""
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.drop(self.tok_emb(idx) + self.pos_emb(positions)[None, :, :])
        hidden = self.ln_f(self.blocks(hidden))
        logits = self.lm_head(hidden)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def named_sparse_linear_modules(model: nn.Module) -> List[Tuple[str, SparseLinear]]:
    """Collect every (qualified name, module) pair in ``model`` that is a SparseLinear."""
    found: List[Tuple[str, SparseLinear]] = []
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            found.append((name, module))
    return found
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    """Return (total params, params inside SparseLinear layers, their fraction of total)."""
    total = sum(p.numel() for p in model.parameters())
    in_linear = sum(
        m.weight.numel() + (m.bias.numel() if m.bias is not None else 0)
        for _, m in named_sparse_linear_modules(model)
    )
    # max(1, total) guards a parameterless model against division by zero.
    return total, in_linear, in_linear / max(1, total)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def configure_sparse_linears(
    model: nn.Module,
    masker: Optional["RowMasker"],
    enabled: bool,
    backward_mode: Optional[str],
) -> None:
    """Push the masker's per-module row masks into every SparseLinear in ``model``.

    Passing ``masker=None`` / ``enabled=False`` restores plain dense behavior.
    """
    wants_sparse_dx = backward_mode == "sparse_dW_sparse_dX"
    for _, module in named_sparse_linear_modules(model):
        rows = None if masker is None else masker.row_mask_for(module)
        module.set_sparse_backward(enabled=enabled, active_rows=rows, sparse_dx=wants_sparse_dx)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# -----------------------------
|
| 356 |
+
# Mask selector
|
| 357 |
+
# -----------------------------
|
| 358 |
+
|
| 359 |
+
class RowMasker:
    """Chooses which output rows of each SparseLinear receive gradient updates.

    All SparseLinear output rows across the model are flattened into one
    global id space of ``n_blocks`` rows. Each step a selection policy picks
    an active subset; per-row gradient-mass statistics (EMA, observation
    counts, a global mass EMA) are maintained from the rows that were
    actually observed.

    Fix vs. previous revision: ``__init__`` built an ``ids`` list that was
    appended to but never read — removed as dead code.
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        active_fraction: float,
        explore_fraction: float,
        mass_beta: float,
        unobserved_decay: float,
        warmup_steps: int,
        ucb_alpha: float,
        mass_init: float,
        device: str,
    ):
        self.model = model
        self.policy = policy
        self.active_fraction = active_fraction
        self.explore_fraction = explore_fraction
        self.mass_beta = mass_beta
        self.unobserved_decay = unobserved_decay
        self.warmup_steps = warmup_steps
        self.ucb_alpha = ucb_alpha
        self.mass_init = mass_init
        self.device = device
        self.step_index = 0

        # Assign each module's output rows a contiguous range of global ids.
        self.linear_modules = [m for _, m in named_sparse_linear_modules(model)]
        self.module_to_ids: Dict[SparseLinear, torch.Tensor] = {}
        offset = 0
        for m in self.linear_modules:
            n = m.weight.shape[0]
            self.module_to_ids[m] = torch.arange(offset, offset + n, device=device)
            offset += n
        self.n_blocks = offset

        # Per-row statistics over the global id space.
        self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
        self.observed_count = torch.zeros(self.n_blocks, device=device)
        self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)

        # Current / previous selection, plus the per-module boolean views.
        self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
        self.row_masks: Dict[SparseLinear, torch.Tensor] = {
            m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
        }

    def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
        """Boolean mask selecting the top ``fraction`` of ``values`` (at least one)."""
        k = max(1, int(fraction * values.numel()))
        mask = torch.zeros_like(values, dtype=torch.bool)
        # Tiny random noise breaks ties so equal values are not always
        # resolved in index order.
        noisy = values + 1e-9 * torch.rand_like(values)
        mask[torch.topk(noisy, k=k).indices] = True
        return mask

    @staticmethod
    def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
        """Jaccard similarity of two boolean masks (0.0 when both are empty)."""
        inter = (a & b).sum().float()
        union = (a | b).sum().float()
        return float((inter / torch.clamp(union, min=1.0)).item())

    def _set_active(self, active: torch.Tensor) -> None:
        """Store the global active mask and slice it into per-module row masks."""
        self.active = active
        self.row_masks = {}
        for m, ids in self.module_to_ids.items():
            self.row_masks[m] = active[ids]

    def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
        """Pick top-scoring rows, then fill the explore quota with random inactive rows."""
        n = self.n_blocks
        k_total = max(1, int(self.active_fraction * n))
        k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
        k_exploit = k_total - k_explore
        active = torch.zeros(n, dtype=torch.bool, device=self.device)

        if k_exploit > 0:
            active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
        if k_explore > 0:
            remaining = torch.nonzero(~active, as_tuple=False).flatten()
            pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
            active[pick] = True
        return active

    def choose_pre_backward(self, step: int) -> None:
        """Select this step's active rows according to ``self.policy``.

        During warmup every row is active. ``oracle_current`` cannot select
        yet (it needs the dense audit gradient) so it clears the mask and
        defers to ``audit_and_update_from_mass``.
        """
        self.step_index = step
        if step < self.warmup_steps:
            self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "oracle_current":
            # Oracle cannot choose until the dense audit gradient is known.
            self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
            return

        if self.policy == "random":
            self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
            return

        if self.policy == "stale_current":
            # Uses the previous step's dense audit mass (the only policy allowed to).
            self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
            return

        if self.policy == "predicted_magnitude":
            self._set_active(self._sample_exploit_explore(self.predicted_mass))
            return

        if self.policy == "ucb_magnitude":
            # UCB-style exploration bonus shrinking with per-row observation count.
            t = max(1, step - self.warmup_steps + 1)
            log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
            bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
            bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
            self._set_active(self._sample_exploit_explore(self.predicted_mass + bonus))
            return

        raise ValueError(f"Unknown policy: {self.policy}")

    @torch.no_grad()
    def current_gradient_mass_from_grads(self) -> torch.Tensor:
        """Per-row L2 gradient mass (weight row plus bias entry) from current ``.grad``s."""
        mass = torch.zeros(self.n_blocks, device=self.device)
        for m, ids in self.module_to_ids.items():
            if m.weight.grad is None:
                continue
            row_sq = m.weight.grad.square().sum(dim=1)
            if m.bias is not None and m.bias.grad is not None:
                row_sq = row_sq + m.bias.grad.square()
            # 1e-30 keeps sqrt differentiable-safe at exactly zero mass.
            mass[ids] = torch.sqrt(row_sq + 1e-30)
        return mass

    @torch.no_grad()
    def audit_and_update_from_mass(self, step: int, mass: torch.Tensor) -> Dict[str, float]:
        """Score the current selection against the dense audit ``mass`` and update stats.

        Returns a dict of diagnostics (cosine, oracle Jaccard, stability,
        coverage, ...). Also refreshes the per-row EMA statistics for the
        rows that were observed this step.
        """
        if step < self.warmup_steps:
            active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
            self._set_active(active)
        elif self.policy == "oracle_current":
            # Oracle selects now that the dense gradient mass is available.
            active = self._topk_mask(mass, self.active_fraction)
            self._set_active(active)
        else:
            active = self.active

        # The masked gradient is a sub-vector of the full one, so its cosine
        # with the full gradient equals the ratio of their norms.
        true_sq = mass.square().sum()
        approx_sq = mass[active].square().sum()
        cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())

        oracle_mask = self._topk_mask(mass, self.active_fraction)
        jacc = self._jaccard(active, oracle_mask)
        stability = self._jaccard(active, self.prev_active)
        self.prev_active = active.clone()

        # Fraction of total gradient mass concentrated in the top 20% of rows.
        k20 = max(1, int(0.2 * self.n_blocks))
        sorted_mass = torch.sort(mass, descending=True).values
        top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())

        # Must be computed before observed_count is incremented below.
        new_active = active & (self.observed_count == 0)

        # Practical rule: update predicted statistics only for active/observed rows.
        self.predicted_mass.mul_(self.unobserved_decay)
        observed = active
        if bool(observed.any().item()):
            obs_mass = mass[observed]
            first_seen = self.observed_count[observed] == 0
            ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
            # First observation snaps to the measured mass instead of blending
            # with the (arbitrary) init value.
            self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
            self.observed_count[observed] += 1.0
            self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()

        # Dense audit signal; only stale_current is allowed to use this for selection.
        self.last_full_mass = mass.detach().clone()

        return {
            "cosine": cosine,
            "norm_ratio": cosine,
            "top20_mass": top20_mass,
            "jacc_oracle": jacc,
            "stability": stability,
            "active_fraction_real": float(active.float().mean().item()),
            "coverage": float((self.observed_count > 0).float().mean().item()),
            "avg_obs_count": float(self.observed_count.mean().item()),
            "new_active_fraction": float(new_active.float().mean().item()),
        }

    def row_mask_for(self, module: SparseLinear) -> Optional[torch.Tensor]:
        """Per-module boolean row mask, or None if the module is unknown."""
        return self.row_masks.get(module)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
# -----------------------------
|
| 544 |
+
# Masked Adam
|
| 545 |
+
# -----------------------------
|
| 546 |
+
|
| 547 |
+
class MaskedAdam:
    """Hand-rolled Adam that can restrict updates of SparseLinear rows to a mask.

    NOTE: unlike torch.optim.Adam this implementation applies no bias
    correction to the first/second moment estimates; early steps are
    therefore scaled differently than standard Adam.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        # Lazily populated per-parameter moment buffers {"m": ..., "v": ...}.
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Maps each SparseLinear weight/bias parameter back to its module so
        # the masker's row mask can be looked up during step().
        self.linear_param: Dict[nn.Parameter, Tuple[SparseLinear, str]] = {}
        for _, m in named_sparse_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        """Drop all gradients (sets .grad to None rather than zeroing in place)."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one (possibly row-masked) Adam update to every parameter with a grad."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            # Optionally skip non-Linear params (embeddings, layernorms, ...)
            # entirely when running a sparse experiment.
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                continue

            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                # Decoupled-from-moments L2: added to the gradient (classic Adam style).
                g = g.add(p, alpha=self.weight_decay)

            # Resolve the row mask for SparseLinear weights/biases; for a weight
            # matrix the 1-D row mask is reshaped to broadcast over columns.
            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                # Dense path: standard in-place Adam update (no bias correction).
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                # Masked path: moments and the parameter are updated only at
                # masked positions; inactive rows keep stale moments.
                mask = row_mask.expand_as(p)
                if not bool(mask.any().item()):
                    continue
                new_m = self.beta1 * m + (1.0 - self.beta1) * g
                new_v = self.beta2 * v + (1.0 - self.beta2) * g * g
                m[mask] = new_m[mask]
                v[mask] = new_v[mask]
                update = m / (torch.sqrt(v) + self.eps)
                p[mask] = p[mask] - self.lr * update[mask]
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
# -----------------------------
|
| 616 |
+
# Training utilities
|
| 617 |
+
# -----------------------------
|
| 618 |
+
|
| 619 |
+
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int, seed: int) -> Dict[str, float]:
    """Average the dense-forward loss over ``eval_iters`` batches per split.

    Temporarily puts the model in eval mode with sparsity disabled, then
    restores train mode before returning.
    """
    model.eval()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    results: Dict[str, float] = {}
    # Distinct generator seed per split so val batches don't mirror train batches.
    for offset, split in ((0, "train"), (100000, "val")):
        gen = make_cpu_generator(seed + offset)
        total = 0.0
        for _ in range(eval_iters):
            xb, yb = corpus.get_batch(split, batch_size, generator=gen)
            _, loss = model(xb, yb)
            total += float(loss.item())
        results[split] = total / eval_iters
    model.train()
    return results
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def dense_audit_pass(model: nn.Module, corpus_batch: Tuple[torch.Tensor, torch.Tensor], opt: MaskedAdam, masker: RowMasker) -> torch.Tensor:
    """Run one dense forward/backward purely to measure per-row gradient mass.

    Gradients are cleared both before and after, so this never changes the
    model — it only returns the masker's per-row gradient-mass vector.
    """
    inputs, targets = corpus_batch
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    opt.zero_grad()
    _, audit_loss = model(inputs, targets)
    audit_loss.backward()
    mass = masker.current_gradient_mass_from_grads()
    opt.zero_grad()
    return mass
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
def sparse_training_backward(
    model: nn.Module,
    corpus_batch: Tuple[torch.Tensor, torch.Tensor],
    opt: MaskedAdam,
    masker: Optional[RowMasker],
    backward_mode: Optional[BackwardMode],
) -> float:
    """Run the training forward/backward with the requested backward sparsity.

    Sparsity is always switched off again before returning, so subsequent
    forwards are dense by default. Returns the scalar training loss.
    """
    inputs, targets = corpus_batch
    opt.zero_grad()

    use_custom_backward = (
        masker is not None and backward_mode is not None and backward_mode != "masked_optimizer"
    )
    if use_custom_backward:
        configure_sparse_linears(model, masker=masker, enabled=True, backward_mode=backward_mode)
    else:
        # Dense backward; "masked_optimizer" sparsifies at the optimizer step instead.
        configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)

    _, loss = model(inputs, targets)
    loss.backward()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    return float(loss.item())
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    backward_mode: Optional[BackwardMode],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    """Train one model configuration end-to-end and return a summary-row dict.

    ``policy=None`` is the dense baseline (no masker, dense backward).
    Otherwise every step does: select rows -> dense audit pass (metrics
    only) -> sparse training backward -> optimizer step. Diagnostic metrics
    are averaged over post-warmup steps.
    """
    # Same model initialization and same minibatch sequence for every run by default.
    set_seed(args.seed + (seed_offset if args.unpaired_seeds else 0))
    data_gen = make_cpu_generator(args.seed + 12345)

    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )

    # Running sums of per-step diagnostics (averaged into the result row).
    sums = {
        "cosine": 0.0,
        "norm_ratio": 0.0,
        "top20_mass": 0.0,
        "jacc_oracle": 0.0,
        "stability": 0.0,
        "active_fraction_real": 0.0,
        "coverage": 0.0,
        "avg_obs_count": 0.0,
        "new_active_fraction": 0.0,
    }
    count = 0

    for step in range(args.steps):
        batch = corpus.get_batch("train", args.batch_size, generator=data_gen)

        if masker is None:
            # Dense baseline: plain backward + full optimizer step.
            loss_value = sparse_training_backward(model, batch, opt, masker=None, backward_mode=None)
            opt.step()
        else:
            # 1) choose rows, 2) dense audit for metrics/stat updates,
            # 3) sparse training backward on the SAME batch, 4) masked step.
            masker.choose_pre_backward(step)
            full_mass = dense_audit_pass(model, batch, opt, masker)
            metrics = masker.audit_and_update_from_mass(step, full_mass)
            if step >= warmup_steps:
                # Only post-warmup steps count toward the averaged diagnostics.
                for k in sums:
                    sums[k] += metrics[k]
                count += 1

            loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode=backward_mode)
            opt.step()

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 555)
            name = "dense" if policy is None else f"{policy}/{backward_mode}"
            print(
                f"{name:38s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"loss={loss_value:.4f} train={losses['train']:.4f} val={losses['val']:.4f}"
            )

    # Final evaluation uses a different seed than the in-loop verbose evals.
    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 999)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "mode": "dense" if backward_mode is None else backward_mode,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None or count == 0:
        # Dense baseline (or all-warmup run): no sparse diagnostics to report.
        row.update({
            "cosine": float("nan"),
            "norm_ratio": float("nan"),
            "top20_mass": float("nan"),
            "jacc_oracle": float("nan"),
            "stability": float("nan"),
            "active_fraction_real": 1.0,
            "coverage": float("nan"),
            "avg_obs_count": float("nan"),
            "new_active_fraction": float("nan"),
        })
    else:
        for k, v in sums.items():
            row[k] = v / count
    return row
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def print_summary(rows: List[Dict[str, float | str]]) -> None:
|
| 775 |
+
print("\nSummary")
|
| 776 |
+
header = (
|
| 777 |
+
f"{'run':>22s} {'mode':>19s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
|
| 778 |
+
f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
|
| 779 |
+
f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
|
| 780 |
+
)
|
| 781 |
+
print(header)
|
| 782 |
+
print("-" * len(header))
|
| 783 |
+
for r in rows:
|
| 784 |
+
print(
|
| 785 |
+
f"{str(r['run']):>22s} "
|
| 786 |
+
f"{str(r['mode']):>19s} "
|
| 787 |
+
f"{float(r['target_active']):7.3f} "
|
| 788 |
+
f"{float(r['active_fraction_real']):7.3f} "
|
| 789 |
+
f"{int(float(r['warmup'])):5d} "
|
| 790 |
+
f"{float(r['explore']):5.2f} "
|
| 791 |
+
f"{float(r['val_loss']):8.4f} "
|
| 792 |
+
f"{float(r['train_loss']):8.4f} "
|
| 793 |
+
f"{float(r['cosine']):7.3f} "
|
| 794 |
+
f"{float(r['top20_mass']):7.3f} "
|
| 795 |
+
f"{float(r['jacc_oracle']):7.3f} "
|
| 796 |
+
f"{float(r['stability']):7.3f} "
|
| 797 |
+
f"{float(r['coverage']):7.3f} "
|
| 798 |
+
f"{float(r['new_active_fraction']):7.3f}"
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
def parse_args() -> argparse.Namespace:
    """Define and parse the experiment's command-line flags."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    # Data / run length.
    add("--text_path", type=str, default=None)
    add("--synthetic_sentences", type=int, default=12000)
    add("--steps", type=int, default=1000)
    add("--quick", action="store_true")
    # Model shape.
    add("--batch_size", type=int, default=32)
    add("--block_size", type=int, default=64)
    add("--n_layer", type=int, default=2)
    add("--n_head", type=int, default=4)
    add("--n_embd", type=int, default=64)
    add("--dropout", type=float, default=0.0)
    # Optimization.
    add("--lr", type=float, default=3e-4)
    add("--weight_decay", type=float, default=0.0)
    # Sparse-selection sweep.
    add("--active_fractions", type=float, nargs="+", default=[0.05, 0.02])
    add("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "random"])
    add(
        "--backward_modes",
        type=str,
        nargs="+",
        default=["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"],
    )
    add("--explore_fractions", type=float, nargs="+", default=[0.0])
    add("--warmup_steps_list", type=int, nargs="+", default=[5])
    add("--mass_beta", type=float, default=0.95)
    add("--unobserved_decay", type=float, default=1.0)
    add("--mass_init", type=float, default=0.0)
    add("--ucb_alpha", type=float, default=1.0)
    add("--freeze_non_linear_when_sparse", action="store_true")
    # Evaluation / reproducibility.
    add("--eval_interval", type=int, default=200)
    add("--eval_iters", type=int, default=20)
    add("--seed", type=int, default=7)
    add("--unpaired_seeds", action="store_true", help="Use different init seeds per run instead of paired seeds.")
    add("--verbose", action="store_true")
    return parser.parse_args()
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
def main() -> None:
    """Entry point: parse args, run the dense baseline plus the sparse sweep, print a summary."""
    args = parse_args()
    if args.quick:
        # Tiny smoke-test configuration overriding the CLI defaults.
        args.steps = 40
        args.eval_iters = 2
        args.batch_size = 8
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 1200
        args.active_fractions = [0.05]
        args.policies = ["predicted_magnitude", "random"]
        args.backward_modes = ["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
        args.explore_fractions = [0.0]
        args.warmup_steps_list = [5]

    # Validate CLI choices up front so bad values fail before any training.
    valid_policies = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
    valid_modes = {"masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"}
    for pol in args.policies:
        if pol not in valid_policies:
            raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid_policies)}")
    for mode in args.backward_modes:
        if mode not in valid_modes:
            raise ValueError(f"Unknown backward mode {mode!r}. Valid modes: {sorted(valid_modes)}")

    set_seed(args.seed)
    dev = device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"backward_modes={args.backward_modes}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")
    print(f"paired_seeds={not args.unpaired_seeds}")

    # Throwaway model only used to report parameter counts.
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")

    if args.dropout != 0.0:
        print("warning: dropout is nonzero; dense audit and sparse training passes may see different dropout masks")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(
        train_run(
            corpus,
            args,
            policy=None,
            backward_mode=None,
            active_fraction=1.0,
            warmup_steps=0,
            explore_fraction=0.0,
            seed_offset=0,
        )
    )

    # Full sweep over backward mode x active fraction x policy x warmup x explore.
    seed_offset = 100
    for mode in args.backward_modes:
        for af in args.active_fractions:
            for pol in args.policies:
                # Exploration only applies to the magnitude-based policies.
                explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
                for warmup in args.warmup_steps_list:
                    for explore in explore_values:
                        print(
                            f"\nRunning mode={mode}, policy={pol}, "
                            f"active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}"
                        )
                        rows.append(
                            train_run(
                                corpus,
                                args,
                                policy=pol,  # type: ignore[arg-type]
                                backward_mode=mode,  # type: ignore[arg-type]
                                active_fraction=af,
                                warmup_steps=warmup,
                                explore_fraction=explore,
                                seed_offset=seed_offset,
                            )
                        )
                        seed_offset += 1

    print_summary(rows)

    print("\nNotes")
    print(" masked_optimizer is the v7-style dense-backward simulation control.")
    print(" sparse_dW_full_dX uses custom Linear backward: sparse weight/bias grads, full input gradient.")
    print(" sparse_dW_sparse_dX uses custom Linear backward: sparse weight/bias grads and sparse input gradient.")
    print(" oracle_current uses dense audit gradients to choose rows; it is an upper bound.")
    print(" predicted_magnitude uses EMA mass from active/observed rows only.")
    print(" random is the sparse-support control.")
    print(" dense audit gradients are still computed every step for metrics/control; this is not a speed benchmark.")
    print(" The key comparison is masked_optimizer vs sparse_dW_full_dX. If they match, the v7 effect survives real dW sparsification.")


if __name__ == "__main__":
    main()
|
experiments/sparse_transformer_v9.py
ADDED
|
@@ -0,0 +1,1042 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v9: no-audit sparse training after dense warmup.
|
| 3 |
+
|
| 4 |
+
v8 proved that the row-sparse mask can be moved into a custom Linear backward.
|
| 5 |
+
v9 removes the remaining dense-audit crutch.
|
| 6 |
+
|
| 7 |
+
Default behavior
|
| 8 |
+
----------------
|
| 9 |
+
1. Run a short dense warmup, usually 5 steps.
|
| 10 |
+
2. Initialize the EMA row-importance predictor from those dense warmup gradients.
|
| 11 |
+
3. After warmup, choose active rows from the predictor.
|
| 12 |
+
4. Train using sparse backward.
|
| 13 |
+
5. Update EMA statistics only from rows that were actually active/observed.
|
| 14 |
+
6. Do not compute dense gradients unless --audit_every > 0.
|
| 15 |
+
|
| 16 |
+
Audit behavior
|
| 17 |
+
--------------
|
| 18 |
+
--audit_every 0
|
| 19 |
+
No dense audit after warmup. Cosine/Jaccard/top20 are unavailable and show as nan.
|
| 20 |
+
|
| 21 |
+
--audit_every N
|
| 22 |
+
Every N steps, run an extra dense backward pass on the same batch only to
|
| 23 |
+
measure cosine/top20/Jaccard. The audit is NOT used to update the selector,
|
| 24 |
+
except for oracle_current, which is explicitly an upper-bound control.
|
| 25 |
+
|
| 26 |
+
This is still not a wall-clock benchmark on vanilla PyTorch/MPS/CPU. The custom
|
| 27 |
+
backward uses indexing and ordinary PyTorch matmuls. The goal is to verify that
|
| 28 |
+
the method survives without dense information after warmup.
|
| 29 |
+
|
| 30 |
+
Examples
|
| 31 |
+
--------
|
| 32 |
+
No-audit practical run:
|
| 33 |
+
python3 sparse_transformer_v9.py \
|
| 34 |
+
--device mps \
|
| 35 |
+
--steps 2000 \
|
| 36 |
+
--active_fractions 0.05 0.02 \
|
| 37 |
+
--warmup_steps_list 5 \
|
| 38 |
+
--policies predicted_magnitude random \
|
| 39 |
+
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
|
| 40 |
+
--audit_every 0
|
| 41 |
+
|
| 42 |
+
Occasional audit for measurement only:
|
| 43 |
+
python3 sparse_transformer_v9.py \
|
| 44 |
+
--steps 2000 \
|
| 45 |
+
--active_fractions 0.05 0.02 \
|
| 46 |
+
--warmup_steps_list 5 \
|
| 47 |
+
--policies predicted_magnitude random \
|
| 48 |
+
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
|
| 49 |
+
--audit_every 100
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
from __future__ import annotations
|
| 54 |
+
|
| 55 |
+
import argparse
|
| 56 |
+
import math
|
| 57 |
+
import random
|
| 58 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 59 |
+
|
| 60 |
+
import torch
|
| 61 |
+
|
| 62 |
+
torch.set_num_threads(1)
|
| 63 |
+
import torch.nn as nn
|
| 64 |
+
import torch.nn.functional as F
|
| 65 |
+
|
| 66 |
+
Policy = Literal["predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"]
|
| 67 |
+
BackwardMode = Literal["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# -----------------------------
|
| 71 |
+
# Reproducibility and device
|
| 72 |
+
# -----------------------------
|
| 73 |
+
|
| 74 |
+
def set_seed(seed: int) -> None:
    """Seed Python's and PyTorch's RNGs (CUDA included when available)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def default_device() -> str:
    """Pick the best available backend, preferring cuda, then mps, then cpu."""
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a CPU-resident torch.Generator seeded with *seed*."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# -----------------------------
|
| 96 |
+
# Data
|
| 97 |
+
# -----------------------------
|
| 98 |
+
|
| 99 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Generate a deterministic synthetic text corpus of templated sentences.

    Six sentence templates (chosen uniformly per line) combine fixed word
    pools, giving the corpus repeating character-level structure that a small
    language model can learn. Returns the lines joined by newlines with a
    trailing newline; the output is fully determined by *seed*.
    """
    rng = random.Random(seed)
    names = ["ada", "turing", "grace", "lovelace", "noether", "shannon", "hopper", "gauss"]
    verbs = ["builds", "tests", "traces", "compresses", "predicts", "routes", "writes", "measures"]
    objects = ["signals", "gradients", "tokens", "circuits", "features", "masks", "errors", "states"]
    adverbs = ["quietly", "boldly", "slowly", "quickly", "cleanly", "strangely", "carefully"]
    clauses = [
        "when the loss falls",
        "after the mask shifts",
        "before the model answers",
        "while the signal drifts",
        "if the pattern repeats",
        "because the tail is noisy",
    ]
    symbols = ["alpha", "beta", "gamma", "delta", "omega", "sigma"]

    lines: List[str] = []
    for _ in range(n_sentences):
        # t picks one of the six templates below.
        t = rng.randrange(6)
        if t == 0:
            line = f"{rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(adverbs)}."
        elif t == 1:
            line = f"{rng.choice(clauses)}, {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)}."
        elif t == 2:
            # Two distinct symbols name a pair of rewrite-style "rules".
            a, b = rng.sample(symbols, 2)
            line = f"rule {a}: {rng.choice(objects)} -> {rng.choice(objects)}; rule {b}: {rng.choice(objects)} -> {rng.choice(objects)}."
        elif t == 3:
            line = f"the {rng.choice(objects)} {rng.choice(verbs)} the {rng.choice(objects)} {rng.choice(adverbs)}."
        elif t == 4:
            # Variable-length symbol run (2-7 tokens).
            seq = " ".join(rng.choice(symbols) for _ in range(rng.randint(2, 7)))
            line = f"sequence {seq} ends when {rng.choice(names)} {rng.choice(verbs)}."
        else:
            line = f"if {rng.choice(objects)} rise then {rng.choice(names)} {rng.choice(verbs)} {rng.choice(objects)} else wait."
        lines.append(line)
    return "\n".join(lines) + "\n"
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class CharCorpus:
    """Character-level corpus with a 90/10 train/val split and random batching."""

    def __init__(self, text: str, block_size: int, device: str):
        # Vocabulary is the sorted set of characters appearing in *text*.
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.device = device

        # Encode the whole corpus once; first 90% trains, remainder validates.
        data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split = int(0.9 * len(data))
        self.train_data = data[:split]
        self.val_data = data[split:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a batch of (input, next-char target) windows from *split*.

        Raises ValueError when the chosen split is shorter than block_size + 1.
        An optional *generator* makes sampling reproducible.
        """
        data = self.train_data if split == "train" else self.val_data
        max_start = len(data) - self.block_size - 1
        if max_start <= 0:
            raise ValueError("Corpus too small for block_size")
        ix = torch.randint(max_start, (batch_size,), generator=generator)
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        # Targets are the same windows shifted one character to the right.
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def load_text(args: argparse.Namespace) -> str:
    """Return the training text: --text_path contents if set, else a synthetic corpus."""
    if not args.text_path:
        return make_synthetic_corpus(args.synthetic_sentences, args.seed)
    with open(args.text_path, "r", encoding="utf-8") as f:
        return f.read()
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# -----------------------------
|
| 174 |
+
# Sparse Linear autograd
|
| 175 |
+
# -----------------------------
|
| 176 |
+
|
| 177 |
+
class MaskedLinearFunction(torch.autograd.Function):
    """Linear forward with a backward restricted to a subset of output rows.

    Forward is a plain F.linear. Backward computes weight/bias gradients only
    for rows flagged in ``active_rows``; when ``sparse_dx`` is True the input
    gradient is also limited to contributions from those rows.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_rows: torch.Tensor,
        sparse_dx: bool,
    ) -> torch.Tensor:
        # Stash tensors for backward; plain flags ride on ctx attributes.
        ctx.save_for_backward(x, weight, active_rows)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):  # type: ignore[override]
        x, weight, active_rows = ctx.saved_tensors
        sparse_dx = bool(ctx.sparse_dx)
        has_bias = bool(ctx.has_bias)

        # Flatten leading dims so the matmuls below are 2-D.
        x_shape = x.shape
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])

        active_idx = torch.nonzero(active_rows, as_tuple=False).flatten()

        # Inactive rows keep zero weight/bias gradients.
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if has_bias else None

        if active_idx.numel() > 0:
            gy_active = gy_flat[:, active_idx]
            grad_weight[active_idx] = gy_active.transpose(0, 1) @ x_flat
            if grad_bias is not None:
                grad_bias[active_idx] = gy_active.sum(dim=0)

            if sparse_dx:
                # Input gradient from active rows only.
                grad_x_flat = gy_active @ weight[active_idx]
            else:
                # Full input gradient through every row.
                grad_x_flat = gy_flat @ weight
        else:
            # This can happen when a global top-k mask selects no rows from a
            # particular layer. Conservative full_dX still propagates through that
            # layer; aggressive sparse_dX cuts it off for that layer.
            if sparse_dx:
                grad_x_flat = torch.zeros_like(x_flat)
            else:
                grad_x_flat = gy_flat @ weight

        grad_x = grad_x_flat.reshape(x_shape)
        # Trailing Nones match the two non-tensor forward inputs.
        return grad_x, grad_weight, grad_bias, None, None
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class SparseLinear(nn.Linear):
    """nn.Linear with an optional row-sparse backward pass."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        # Dense behavior until set_sparse_backward() switches modes.
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_rows: Optional[torch.Tensor] = None

    def set_sparse_backward(self, enabled: bool, active_rows: Optional[torch.Tensor], sparse_dx: bool) -> None:
        """Configure whether backward is row-sparse and which rows are active."""
        self.sparse_enabled = bool(enabled)
        self.sparse_dx = bool(sparse_dx)
        self.active_rows = active_rows

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        use_custom = self.sparse_enabled and self.active_rows is not None
        if use_custom:
            return MaskedLinearFunction.apply(x, self.weight, self.bias, self.active_rows, self.sparse_dx)
        return F.linear(x, self.weight, self.bias)
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# -----------------------------
|
| 251 |
+
# Mini GPT
|
| 252 |
+
# -----------------------------
|
| 253 |
+
|
| 254 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on SparseLinear projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused q/k/v projection and output projection, both maskable.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask, shaped (1, 1, T, T) for broadcasting.
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Scaled dot-product scores with future positions masked pre-softmax.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        # Merge heads back to (B, T, C) before the output projection.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class FeedForward(nn.Module):
    """Transformer MLP: 4x expansion, GELU, projection back down, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.dropout(projected)
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class MiniGPT(nn.Module):
    """Small GPT-style decoder: token+position embeddings, blocks, LM head."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        # The LM head is also a SparseLinear, so its rows join the row masking.
        self.lm_head = SparseLinear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when *targets* is None."""
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def named_sparse_linear_modules(model: nn.Module) -> List[Tuple[str, SparseLinear]]:
    """Collect every SparseLinear submodule of *model* with its qualified name."""
    found: List[Tuple[str, SparseLinear]] = []
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            found.append((name, module))
    return found
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def parameter_fractions(model: nn.Module) -> Tuple[int, int, float]:
    """Return (total params, params inside SparseLinear layers, their ratio)."""
    total = sum(p.numel() for p in model.parameters())
    in_linear = 0
    for _, module in named_sparse_linear_modules(model):
        in_linear += module.weight.numel()
        if module.bias is not None:
            in_linear += module.bias.numel()
    # max(1, total) guards the ratio against an empty model.
    return total, in_linear, in_linear / max(1, total)
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def configure_sparse_linears(
    model: nn.Module,
    masker: Optional["RowMasker"],
    enabled: bool,
    backward_mode: Optional[str],
) -> None:
    """Push the masker's per-layer row masks into every SparseLinear of *model*."""
    wants_sparse_dx = backward_mode == "sparse_dW_sparse_dX"
    for _, module in named_sparse_linear_modules(model):
        rows = None if masker is None else masker.row_mask_for(module)
        module.set_sparse_backward(enabled=enabled, active_rows=rows, sparse_dx=wants_sparse_dx)
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# -----------------------------
|
| 358 |
+
# Mask selector
|
| 359 |
+
# -----------------------------
|
| 360 |
+
|
| 361 |
+
class RowMasker:
|
| 362 |
+
def __init__(
|
| 363 |
+
self,
|
| 364 |
+
model: nn.Module,
|
| 365 |
+
policy: Policy,
|
| 366 |
+
active_fraction: float,
|
| 367 |
+
explore_fraction: float,
|
| 368 |
+
mass_beta: float,
|
| 369 |
+
unobserved_decay: float,
|
| 370 |
+
warmup_steps: int,
|
| 371 |
+
ucb_alpha: float,
|
| 372 |
+
mass_init: float,
|
| 373 |
+
device: str,
|
| 374 |
+
):
|
| 375 |
+
self.model = model
|
| 376 |
+
self.policy = policy
|
| 377 |
+
self.active_fraction = active_fraction
|
| 378 |
+
self.explore_fraction = explore_fraction
|
| 379 |
+
self.mass_beta = mass_beta
|
| 380 |
+
self.unobserved_decay = unobserved_decay
|
| 381 |
+
self.warmup_steps = warmup_steps
|
| 382 |
+
self.ucb_alpha = ucb_alpha
|
| 383 |
+
self.mass_init = mass_init
|
| 384 |
+
self.device = device
|
| 385 |
+
self.step_index = 0
|
| 386 |
+
|
| 387 |
+
self.linear_modules = [m for _, m in named_sparse_linear_modules(model)]
|
| 388 |
+
self.module_to_ids: Dict[SparseLinear, torch.Tensor] = {}
|
| 389 |
+
ids = []
|
| 390 |
+
offset = 0
|
| 391 |
+
for m in self.linear_modules:
|
| 392 |
+
n = m.weight.shape[0]
|
| 393 |
+
block_ids = torch.arange(offset, offset + n, device=device)
|
| 394 |
+
self.module_to_ids[m] = block_ids
|
| 395 |
+
ids.append(block_ids)
|
| 396 |
+
offset += n
|
| 397 |
+
self.n_blocks = offset
|
| 398 |
+
|
| 399 |
+
self.predicted_mass = torch.full((self.n_blocks,), mass_init, device=device)
|
| 400 |
+
self.last_full_mass = torch.full((self.n_blocks,), mass_init, device=device)
|
| 401 |
+
self.observed_count = torch.zeros(self.n_blocks, device=device)
|
| 402 |
+
self.global_mass_ema = torch.tensor(max(mass_init, 1e-6), device=device)
|
| 403 |
+
|
| 404 |
+
self.prev_active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
|
| 405 |
+
self.active = torch.zeros(self.n_blocks, dtype=torch.bool, device=device)
|
| 406 |
+
self.row_masks: Dict[SparseLinear, torch.Tensor] = {
|
| 407 |
+
m: torch.zeros(m.weight.shape[0], dtype=torch.bool, device=device) for m in self.linear_modules
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
def _topk_mask(self, values: torch.Tensor, fraction: float) -> torch.Tensor:
|
| 411 |
+
k = max(1, int(fraction * values.numel()))
|
| 412 |
+
mask = torch.zeros_like(values, dtype=torch.bool)
|
| 413 |
+
noisy = values + 1e-9 * torch.rand_like(values)
|
| 414 |
+
mask[torch.topk(noisy, k=k).indices] = True
|
| 415 |
+
return mask
|
| 416 |
+
|
| 417 |
+
@staticmethod
|
| 418 |
+
def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
|
| 419 |
+
inter = (a & b).sum().float()
|
| 420 |
+
union = (a | b).sum().float()
|
| 421 |
+
return float((inter / torch.clamp(union, min=1.0)).item())
|
| 422 |
+
|
| 423 |
+
def _set_active(self, active: torch.Tensor) -> None:
|
| 424 |
+
self.active = active
|
| 425 |
+
self.row_masks = {}
|
| 426 |
+
for m, ids in self.module_to_ids.items():
|
| 427 |
+
self.row_masks[m] = active[ids]
|
| 428 |
+
|
| 429 |
+
def _sample_exploit_explore(self, scores: torch.Tensor) -> torch.Tensor:
|
| 430 |
+
n = self.n_blocks
|
| 431 |
+
k_total = max(1, int(self.active_fraction * n))
|
| 432 |
+
k_explore = min(k_total, max(0, int(self.explore_fraction * k_total)))
|
| 433 |
+
k_exploit = k_total - k_explore
|
| 434 |
+
active = torch.zeros(n, dtype=torch.bool, device=self.device)
|
| 435 |
+
|
| 436 |
+
if k_exploit > 0:
|
| 437 |
+
active[torch.topk(scores + 1e-9 * torch.rand_like(scores), k=k_exploit).indices] = True
|
| 438 |
+
if k_explore > 0:
|
| 439 |
+
remaining = torch.nonzero(~active, as_tuple=False).flatten()
|
| 440 |
+
pick = remaining[torch.randperm(remaining.numel(), device=self.device)[:k_explore]]
|
| 441 |
+
active[pick] = True
|
| 442 |
+
return active
|
| 443 |
+
|
| 444 |
+
def choose_pre_backward(self, step: int) -> None:
|
| 445 |
+
self.step_index = step
|
| 446 |
+
if step < self.warmup_steps:
|
| 447 |
+
self._set_active(torch.ones(self.n_blocks, dtype=torch.bool, device=self.device))
|
| 448 |
+
return
|
| 449 |
+
|
| 450 |
+
if self.policy == "oracle_current":
|
| 451 |
+
# Oracle cannot choose until the dense audit gradient is known.
|
| 452 |
+
self._set_active(torch.zeros(self.n_blocks, dtype=torch.bool, device=self.device))
|
| 453 |
+
return
|
| 454 |
+
|
| 455 |
+
if self.policy == "random":
|
| 456 |
+
self._set_active(self._sample_exploit_explore(torch.rand(self.n_blocks, device=self.device)))
|
| 457 |
+
return
|
| 458 |
+
|
| 459 |
+
if self.policy == "stale_current":
|
| 460 |
+
self._set_active(self._topk_mask(self.last_full_mass, self.active_fraction))
|
| 461 |
+
return
|
| 462 |
+
|
| 463 |
+
if self.policy == "predicted_magnitude":
|
| 464 |
+
self._set_active(self._sample_exploit_explore(self.predicted_mass))
|
| 465 |
+
return
|
| 466 |
+
|
| 467 |
+
if self.policy == "ucb_magnitude":
|
| 468 |
+
t = max(1, step - self.warmup_steps + 1)
|
| 469 |
+
log_term = torch.log(torch.tensor(float(t + 2), device=self.device))
|
| 470 |
+
bonus_scale = torch.clamp(self.global_mass_ema, min=1e-8)
|
| 471 |
+
bonus = self.ucb_alpha * bonus_scale * torch.sqrt(log_term / (self.observed_count + 1.0))
|
| 472 |
+
self._set_active(self._sample_exploit_explore(self.predicted_mass + bonus))
|
| 473 |
+
return
|
| 474 |
+
|
| 475 |
+
raise ValueError(f"Unknown policy: {self.policy}")
|
| 476 |
+
|
| 477 |
+
@torch.no_grad()
|
| 478 |
+
def current_gradient_mass_from_grads(self) -> torch.Tensor:
|
| 479 |
+
mass = torch.zeros(self.n_blocks, device=self.device)
|
| 480 |
+
for m, ids in self.module_to_ids.items():
|
| 481 |
+
if m.weight.grad is None:
|
| 482 |
+
continue
|
| 483 |
+
row_sq = m.weight.grad.square().sum(dim=1)
|
| 484 |
+
if m.bias is not None and m.bias.grad is not None:
|
| 485 |
+
row_sq = row_sq + m.bias.grad.square()
|
| 486 |
+
mass[ids] = torch.sqrt(row_sq + 1e-30)
|
| 487 |
+
return mass
|
| 488 |
+
|
| 489 |
+
@torch.no_grad()
|
| 490 |
+
@torch.no_grad()
|
| 491 |
+
def update_predictor_from_observed_mass(self, mass: torch.Tensor, observed: Optional[torch.Tensor] = None) -> Dict[str, float]:
|
| 492 |
+
"""Update EMA statistics only for observed rows.
|
| 493 |
+
|
| 494 |
+
After warmup, sparse backward only gives trustworthy gradients for active
|
| 495 |
+
rows, so only those rows are allowed to update predicted_mass.
|
| 496 |
+
"""
|
| 497 |
+
if observed is None:
|
| 498 |
+
observed = self.active
|
| 499 |
+
|
| 500 |
+
new_active = observed & (self.observed_count == 0)
|
| 501 |
+
self.predicted_mass.mul_(self.unobserved_decay)
|
| 502 |
+
|
| 503 |
+
if bool(observed.any().item()):
|
| 504 |
+
obs_mass = mass[observed]
|
| 505 |
+
first_seen = self.observed_count[observed] == 0
|
| 506 |
+
ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
|
| 507 |
+
self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
|
| 508 |
+
self.observed_count[observed] += 1.0
|
| 509 |
+
self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()
|
| 510 |
+
|
| 511 |
+
stability = self._jaccard(self.active, self.prev_active)
|
| 512 |
+
self.prev_active = self.active.clone()
|
| 513 |
+
|
| 514 |
+
return {
|
| 515 |
+
"stability": stability,
|
| 516 |
+
"active_fraction_real": float(self.active.float().mean().item()),
|
| 517 |
+
"coverage": float((self.observed_count > 0).float().mean().item()),
|
| 518 |
+
"avg_obs_count": float(self.observed_count.mean().item()),
|
| 519 |
+
"new_active_fraction": float(new_active.float().mean().item()),
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
@torch.no_grad()
|
| 523 |
+
def audit_metrics_from_mass(self, mass: torch.Tensor) -> Dict[str, float]:
|
| 524 |
+
"""Compute dense-audit metrics without updating the practical selector."""
|
| 525 |
+
active = self.active
|
| 526 |
+
true_sq = mass.square().sum()
|
| 527 |
+
approx_sq = mass[active].square().sum()
|
| 528 |
+
cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
|
| 529 |
+
|
| 530 |
+
oracle_mask = self._topk_mask(mass, self.active_fraction)
|
| 531 |
+
jacc = self._jaccard(active, oracle_mask)
|
| 532 |
+
|
| 533 |
+
k20 = max(1, int(0.2 * self.n_blocks))
|
| 534 |
+
sorted_mass = torch.sort(mass, descending=True).values
|
| 535 |
+
top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())
|
| 536 |
+
|
| 537 |
+
return {
|
| 538 |
+
"cosine": cosine,
|
| 539 |
+
"norm_ratio": cosine,
|
| 540 |
+
"top20_mass": top20_mass,
|
| 541 |
+
"jacc_oracle": jacc,
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
def audit_and_update_from_mass(self, step: int, mass: torch.Tensor) -> Dict[str, float]:
|
| 545 |
+
if step < self.warmup_steps:
|
| 546 |
+
active = torch.ones(self.n_blocks, dtype=torch.bool, device=self.device)
|
| 547 |
+
self._set_active(active)
|
| 548 |
+
elif self.policy == "oracle_current":
|
| 549 |
+
active = self._topk_mask(mass, self.active_fraction)
|
| 550 |
+
self._set_active(active)
|
| 551 |
+
else:
|
| 552 |
+
active = self.active
|
| 553 |
+
|
| 554 |
+
true_sq = mass.square().sum()
|
| 555 |
+
approx_sq = mass[active].square().sum()
|
| 556 |
+
cosine = float((torch.sqrt(approx_sq + 1e-30) / torch.sqrt(true_sq + 1e-30)).item())
|
| 557 |
+
|
| 558 |
+
oracle_mask = self._topk_mask(mass, self.active_fraction)
|
| 559 |
+
jacc = self._jaccard(active, oracle_mask)
|
| 560 |
+
stability = self._jaccard(active, self.prev_active)
|
| 561 |
+
self.prev_active = active.clone()
|
| 562 |
+
|
| 563 |
+
k20 = max(1, int(0.2 * self.n_blocks))
|
| 564 |
+
sorted_mass = torch.sort(mass, descending=True).values
|
| 565 |
+
top20_mass = float((sorted_mass[:k20].sum() / (sorted_mass.sum() + 1e-12)).item())
|
| 566 |
+
|
| 567 |
+
new_active = active & (self.observed_count == 0)
|
| 568 |
+
|
| 569 |
+
# Practical rule: update predicted statistics only for active/observed rows.
|
| 570 |
+
self.predicted_mass.mul_(self.unobserved_decay)
|
| 571 |
+
observed = active
|
| 572 |
+
if bool(observed.any().item()):
|
| 573 |
+
obs_mass = mass[observed]
|
| 574 |
+
first_seen = self.observed_count[observed] == 0
|
| 575 |
+
ema_mass = self.mass_beta * self.predicted_mass[observed] + (1.0 - self.mass_beta) * obs_mass
|
| 576 |
+
self.predicted_mass[observed] = torch.where(first_seen, obs_mass, ema_mass)
|
| 577 |
+
self.observed_count[observed] += 1.0
|
| 578 |
+
self.global_mass_ema = self.mass_beta * self.global_mass_ema + (1.0 - self.mass_beta) * obs_mass.mean()
|
| 579 |
+
|
| 580 |
+
# Dense audit signal; only stale_current is allowed to use this for selection.
|
| 581 |
+
self.last_full_mass = mass.detach().clone()
|
| 582 |
+
|
| 583 |
+
return {
|
| 584 |
+
"cosine": cosine,
|
| 585 |
+
"norm_ratio": cosine,
|
| 586 |
+
"top20_mass": top20_mass,
|
| 587 |
+
"jacc_oracle": jacc,
|
| 588 |
+
"stability": stability,
|
| 589 |
+
"active_fraction_real": float(active.float().mean().item()),
|
| 590 |
+
"coverage": float((self.observed_count > 0).float().mean().item()),
|
| 591 |
+
"avg_obs_count": float(self.observed_count.mean().item()),
|
| 592 |
+
"new_active_fraction": float(new_active.float().mean().item()),
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
def row_mask_for(self, module: SparseLinear) -> Optional[torch.Tensor]:
    """Return the per-row activity mask recorded for *module*, or None if absent."""
    mask = self.row_masks.get(module)
    return mask
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
# -----------------------------
|
| 600 |
+
# Masked Adam
|
| 601 |
+
# -----------------------------
|
| 602 |
+
|
| 603 |
+
class MaskedAdam:
    """Adam-style optimizer that can restrict SparseLinear updates to active rows.

    State (first/second moment) is kept per-parameter in ``self.state``.
    Note: no Adam bias correction is applied anywhere in ``step``; the raw
    EMA moments are used directly.
    """

    def __init__(
        self,
        model: nn.Module,
        masker: Optional[RowMasker],
        lr: float,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.0,
        freeze_non_linear_when_sparse: bool = False,
    ):
        # masker may be None, in which case every parameter gets a dense update.
        self.model = model
        self.masker = masker
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.freeze_non_linear_when_sparse = freeze_non_linear_when_sparse
        # Lazily-populated Adam moments, keyed by parameter object identity.
        self.state: Dict[nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Maps each SparseLinear weight/bias parameter back to (module, kind)
        # so step() can look up the module's row mask.
        self.linear_param: Dict[nn.Parameter, Tuple[SparseLinear, str]] = {}
        for _, m in named_sparse_linear_modules(model):
            self.linear_param[m.weight] = (m, "weight")
            if m.bias is not None:
                self.linear_param[m.bias] = (m, "bias")

    def zero_grad(self) -> None:
        """Drop all gradients (sets .grad to None rather than zeroing in place)."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self) -> None:
        """Apply one optimizer step; SparseLinear rows outside the mask are untouched."""
        for p in self.model.parameters():
            if p.grad is None:
                continue
            # Optionally freeze everything that is not a SparseLinear parameter
            # when running in sparse mode.
            if self.masker is not None and self.freeze_non_linear_when_sparse and p not in self.linear_param:
                continue

            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            g = p.grad
            if self.weight_decay:
                # Decoupled-from-grad-buffer L2: g is rebound, p.grad is not mutated.
                g = g.add(p, alpha=self.weight_decay)

            row_mask = None
            if self.masker is not None and p in self.linear_param:
                module, kind = self.linear_param[p]
                base = self.masker.row_mask_for(module)
                if base is not None:
                    # Weight: broadcast the row mask over all trailing dims; bias: use as-is.
                    row_mask = base.view(-1, *([1] * (p.ndim - 1))) if kind == "weight" else base

            if row_mask is None:
                # Dense Adam update (no bias correction).
                m.mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
                v.mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
                p.add_(m / (torch.sqrt(v) + self.eps), alpha=-self.lr)
            else:
                # MPS can mis-handle expanded boolean masks for row-wise assignment
                # (e.g. reporting nonsense out-of-bounds indices). Use explicit
                # row indices and index_copy_ instead. This also avoids materializing
                # a full expanded mask for weight matrices.
                active_rows = row_mask.reshape(-1).nonzero(as_tuple=False).flatten()
                if active_rows.numel() == 0:
                    continue

                # Gather only the active rows, update them out-of-place, then
                # write moments and parameters back with index_copy_.
                m_rows = m.index_select(0, active_rows)
                v_rows = v.index_select(0, active_rows)
                g_rows = g.index_select(0, active_rows)

                new_m_rows = self.beta1 * m_rows + (1.0 - self.beta1) * g_rows
                new_v_rows = self.beta2 * v_rows + (1.0 - self.beta2) * g_rows * g_rows
                update_rows = new_m_rows / (torch.sqrt(new_v_rows) + self.eps)
                p_rows = p.index_select(0, active_rows) - self.lr * update_rows

                m.index_copy_(0, active_rows, new_m_rows)
                v.index_copy_(0, active_rows, new_v_rows)
                p.index_copy_(0, active_rows, p_rows)
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
# -----------------------------
|
| 683 |
+
# Training utilities
|
| 684 |
+
# -----------------------------
|
| 685 |
+
|
| 686 |
+
@torch.no_grad()
def estimate_loss(model: nn.Module, corpus: CharCorpus, batch_size: int, eval_iters: int, seed: int) -> Dict[str, float]:
    """Average train/val loss over eval_iters deterministic batches.

    Sparse backward paths are disabled for the duration, and the model is
    restored to train mode before returning.
    """
    model.eval()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    results: Dict[str, float] = {}
    # The val split gets a fixed seed offset so the two splits never share batches.
    for seed_shift, split in ((0, "train"), (100000, "val")):
        gen = make_cpu_generator(seed + seed_shift)
        total = 0.0
        for _ in range(eval_iters):
            x, y = corpus.get_batch(split, batch_size, generator=gen)
            _, loss = model(x, y)
            total += float(loss.item())
        results[split] = total / eval_iters
    model.train()
    return results
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
def dense_audit_pass(model: nn.Module, corpus_batch: Tuple[torch.Tensor, torch.Tensor], opt: MaskedAdam, masker: RowMasker) -> torch.Tensor:
    """Run one dense forward/backward on *corpus_batch* and return per-block gradient mass.

    Gradients are cleared both before and after, so this pass never changes
    parameters and leaves no stale .grad behind.
    """
    inputs, targets = corpus_batch
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    opt.zero_grad()
    _, audit_loss = model(inputs, targets)
    audit_loss.backward()
    block_mass = masker.current_gradient_mass_from_grads()
    opt.zero_grad()
    return block_mass
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def sparse_training_backward(
    model: nn.Module,
    corpus_batch: Tuple[torch.Tensor, torch.Tensor],
    opt: MaskedAdam,
    masker: Optional[RowMasker],
    backward_mode: Optional[BackwardMode],
) -> float:
    """One forward/backward pass, optionally through the custom sparse backward.

    Returns the scalar loss. Sparse linears are always reset to the disabled
    dense configuration before returning.
    """
    x, y = corpus_batch
    opt.zero_grad()

    # The custom backward is only engaged for a real masker with a true
    # sparse mode; "masked_optimizer" simulates sparsity at the optimizer level.
    use_custom_backward = (
        masker is not None and backward_mode is not None and backward_mode != "masked_optimizer"
    )
    if use_custom_backward:
        configure_sparse_linears(model, masker=masker, enabled=True, backward_mode=backward_mode)
    else:
        configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)

    _, loss = model(x, y)
    loss.backward()
    configure_sparse_linears(model, masker=None, enabled=False, backward_mode=None)
    return float(loss.item())
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def train_run(
    corpus: CharCorpus,
    args: argparse.Namespace,
    policy: Optional[Policy],
    backward_mode: Optional[BackwardMode],
    active_fraction: float,
    warmup_steps: int,
    explore_fraction: float,
    seed_offset: int,
) -> Dict[str, float | str]:
    """Train one MiniGPT configuration and return a summary dict for the table.

    policy=None means the dense baseline (no masker, no sparse backward).
    Metrics in the returned row are averages over whatever audit/predictor
    updates fired during training; missing metrics come back as NaN.
    """
    # Same model initialization and same minibatch sequence for every run by default.
    set_seed(args.seed + (seed_offset if args.unpaired_seeds else 0))
    data_gen = make_cpu_generator(args.seed + 12345)

    dev = corpus.device
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)

    masker = None
    if policy is not None:
        masker = RowMasker(
            model=model,
            policy=policy,
            active_fraction=active_fraction,
            explore_fraction=explore_fraction,
            mass_beta=args.mass_beta,
            unobserved_decay=args.unobserved_decay,
            warmup_steps=warmup_steps,
            ucb_alpha=args.ucb_alpha,
            mass_init=args.mass_init,
            device=dev,
        )
    opt = MaskedAdam(
        model,
        masker,
        lr=args.lr,
        weight_decay=args.weight_decay,
        freeze_non_linear_when_sparse=args.freeze_non_linear_when_sparse,
    )

    # Running sums/counts for averaged diagnostics; keys must match RowMasker's
    # metric dicts so add_metrics can filter on membership.
    sums = {
        "cosine": 0.0,
        "norm_ratio": 0.0,
        "top20_mass": 0.0,
        "jacc_oracle": 0.0,
        "stability": 0.0,
        "active_fraction_real": 0.0,
        "coverage": 0.0,
        "avg_obs_count": 0.0,
        "new_active_fraction": 0.0,
    }
    counts = {k: 0 for k in sums}

    def add_metrics(metrics: Dict[str, float]) -> None:
        # Accumulate only known keys; counts track how many updates each key got.
        for k, v in metrics.items():
            if k in sums:
                sums[k] += float(v)
                counts[k] += 1

    for step in range(args.steps):
        batch = corpus.get_batch("train", args.batch_size, generator=data_gen)

        if masker is None:
            # Dense baseline: ordinary backward, dense optimizer step.
            loss_value = sparse_training_backward(model, batch, opt, masker=None, backward_mode=None)
            opt.step()
        else:
            if step < warmup_steps:
                # Dense bootstrap. Every row is active and every row updates the predictor.
                masker._set_active(torch.ones(masker.n_blocks, dtype=torch.bool, device=dev))
                loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode="masked_optimizer")
                full_mass = masker.current_gradient_mass_from_grads()
                masker.last_full_mass = full_mass.detach().clone()
                add_metrics(masker.audit_metrics_from_mass(full_mass))
                add_metrics(masker.update_predictor_from_observed_mass(full_mass, observed=masker.active))
                opt.step()
            else:
                # Select active rows before the (possibly sparse) backward pass.
                masker.choose_pre_backward(step)

                if policy == "oracle_current":
                    # Explicit upper bound. Oracle necessarily computes dense gradients to choose rows.
                    full_mass = dense_audit_pass(model, batch, opt, masker)
                    masker._set_active(masker._topk_mask(full_mass, active_fraction))
                    masker.last_full_mass = full_mass.detach().clone()
                    add_metrics(masker.audit_metrics_from_mass(full_mass))
                elif args.audit_every > 0 and ((step - warmup_steps) % args.audit_every == 0):
                    # Measurement only. Do not update predicted_magnitude/ucb/random with this dense mass.
                    full_mass = dense_audit_pass(model, batch, opt, masker)
                    add_metrics(masker.audit_metrics_from_mass(full_mass))
                    if policy == "stale_current":
                        masker.last_full_mass = full_mass.detach().clone()

                loss_value = sparse_training_backward(model, batch, opt, masker=masker, backward_mode=backward_mode)

                # Practical selector update: only active rows were observed by the training backward pass.
                observed_mass = masker.current_gradient_mass_from_grads()
                add_metrics(masker.update_predictor_from_observed_mass(observed_mass, observed=masker.active))
                opt.step()

        if args.verbose and (step % args.eval_interval == 0 or step == args.steps - 1):
            losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 555)
            name = "dense" if policy is None else f"{policy}/{backward_mode}"
            print(
                f"{name:38s} step={step:5d} warm={warmup_steps:4d} explore={explore_fraction:.2f} "
                f"loss={loss_value:.4f} train={losses['train']:.4f} val={losses['val']:.4f}"
            )

    # Final evaluation with a seed distinct from the in-loop verbose evals.
    losses = estimate_loss(model, corpus, args.batch_size, args.eval_iters, seed=args.seed + 999)
    row: Dict[str, float | str] = {
        "run": "dense_baseline" if policy is None else policy,
        "mode": "dense" if backward_mode is None else backward_mode,
        "target_active": 1.0 if policy is None else active_fraction,
        "warmup": warmup_steps,
        "explore": explore_fraction if policy is not None else 0.0,
        "train_loss": losses["train"],
        "val_loss": losses["val"],
    }
    if masker is None:
        # The dense baseline has no selector diagnostics.
        row.update({
            "cosine": float("nan"),
            "norm_ratio": float("nan"),
            "top20_mass": float("nan"),
            "jacc_oracle": float("nan"),
            "stability": float("nan"),
            "active_fraction_real": 1.0,
            "coverage": float("nan"),
            "avg_obs_count": float("nan"),
            "new_active_fraction": float("nan"),
        })
    else:
        for k in sums:
            row[k] = (sums[k] / counts[k]) if counts[k] > 0 else float("nan")
    return row
|
| 866 |
+
|
| 867 |
+
def print_summary(rows: List[Dict[str, float | str]]) -> None:
|
| 868 |
+
print("\nSummary")
|
| 869 |
+
header = (
|
| 870 |
+
f"{'run':>22s} {'mode':>19s} {'target':>7s} {'actual':>7s} {'warm':>5s} {'expl':>5s} "
|
| 871 |
+
f"{'val':>8s} {'train':>8s} {'cos':>7s} {'top20':>7s} {'jacc':>7s} "
|
| 872 |
+
f"{'stable':>7s} {'cover':>7s} {'new':>7s}"
|
| 873 |
+
)
|
| 874 |
+
print(header)
|
| 875 |
+
print("-" * len(header))
|
| 876 |
+
for r in rows:
|
| 877 |
+
print(
|
| 878 |
+
f"{str(r['run']):>22s} "
|
| 879 |
+
f"{str(r['mode']):>19s} "
|
| 880 |
+
f"{float(r['target_active']):7.3f} "
|
| 881 |
+
f"{float(r['active_fraction_real']):7.3f} "
|
| 882 |
+
f"{int(float(r['warmup'])):5d} "
|
| 883 |
+
f"{float(r['explore']):5.2f} "
|
| 884 |
+
f"{float(r['val_loss']):8.4f} "
|
| 885 |
+
f"{float(r['train_loss']):8.4f} "
|
| 886 |
+
f"{float(r['cosine']):7.3f} "
|
| 887 |
+
f"{float(r['top20_mass']):7.3f} "
|
| 888 |
+
f"{float(r['jacc_oracle']):7.3f} "
|
| 889 |
+
f"{float(r['stability']):7.3f} "
|
| 890 |
+
f"{float(r['coverage']):7.3f} "
|
| 891 |
+
f"{float(r['new_active_fraction']):7.3f}"
|
| 892 |
+
)
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
def parse_args() -> argparse.Namespace:
    """Define and parse the experiment's command-line interface."""
    parser = argparse.ArgumentParser()
    # Data / run-size knobs.
    parser.add_argument("--text_path", type=str, default=None)
    parser.add_argument("--synthetic_sentences", type=int, default=12000)
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--batch_size", type=int, default=32)
    # Model architecture.
    parser.add_argument("--block_size", type=int, default=64)
    parser.add_argument("--n_layer", type=int, default=2)
    parser.add_argument("--n_head", type=int, default=4)
    parser.add_argument("--n_embd", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.0)
    # Optimizer.
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    # Sparse-selection sweep.
    parser.add_argument("--active_fractions", type=float, nargs="+", default=[0.05, 0.02])
    parser.add_argument("--policies", type=str, nargs="+", default=["oracle_current", "predicted_magnitude", "random"])
    parser.add_argument(
        "--backward_modes",
        type=str,
        nargs="+",
        default=["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"],
    )
    parser.add_argument("--explore_fractions", type=float, nargs="+", default=[0.0])
    parser.add_argument("--warmup_steps_list", type=int, nargs="+", default=[5])
    parser.add_argument("--mass_beta", type=float, default=0.95)
    parser.add_argument("--unobserved_decay", type=float, default=1.0)
    parser.add_argument("--mass_init", type=float, default=0.0)
    parser.add_argument("--ucb_alpha", type=float, default=1.0)
    parser.add_argument("--freeze_non_linear_when_sparse", action="store_true")
    # Evaluation / reproducibility.
    parser.add_argument("--eval_interval", type=int, default=200)
    parser.add_argument("--eval_iters", type=int, default=20)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda", "mps"])
    parser.add_argument("--audit_every", type=int, default=0, help="Dense audit interval after warmup. 0 disables audits except oracle_current.")
    parser.add_argument("--unpaired_seeds", action="store_true", help="Use different init seeds per run instead of paired seeds.")
    parser.add_argument("--verbose", action="store_true")
    return parser.parse_args()
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
def main() -> None:
    """Entry point: run the dense baseline plus the configured sparse sweep."""
    args = parse_args()
    if args.quick:
        # Shrink every dimension of the experiment for a smoke-test run.
        args.steps = 40
        args.eval_iters = 2
        args.batch_size = 8
        args.block_size = 32
        args.n_layer = 1
        args.n_embd = 32
        args.n_head = 4
        args.synthetic_sentences = 1200
        args.active_fractions = [0.05]
        args.policies = ["predicted_magnitude", "random"]
        args.backward_modes = ["masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
        args.explore_fractions = [0.0]
        args.warmup_steps_list = [5]
        args.audit_every = 10

    # Validate CLI choices eagerly so a typo fails before any training starts.
    valid_policies = {"predicted_magnitude", "ucb_magnitude", "oracle_current", "stale_current", "random"}
    valid_modes = {"masked_optimizer", "sparse_dW_full_dX", "sparse_dW_sparse_dX"}
    for pol in args.policies:
        if pol not in valid_policies:
            raise ValueError(f"Unknown policy {pol!r}. Valid policies: {sorted(valid_policies)}")
    for mode in args.backward_modes:
        if mode not in valid_modes:
            raise ValueError(f"Unknown backward mode {mode!r}. Valid modes: {sorted(valid_modes)}")

    set_seed(args.seed)
    dev = args.device if args.device != "auto" else default_device()
    print(f"device={dev}")
    corpus = CharCorpus(load_text(args), args.block_size, dev)
    print(f"vocab_size={corpus.vocab_size} train_tokens={len(corpus.train_data)} val_tokens={len(corpus.val_data)}")
    print(f"policies={args.policies}")
    print(f"backward_modes={args.backward_modes}")
    print(f"active_fractions={args.active_fractions}")
    print(f"warmup_steps_list={args.warmup_steps_list} explore_fractions={args.explore_fractions}")
    print(f"mass_init={args.mass_init} mass_beta={args.mass_beta} ucb_alpha={args.ucb_alpha}")
    print(f"paired_seeds={not args.unpaired_seeds}")
    print(f"audit_every={args.audit_every} (0 means no dense audit after warmup, except oracle_current)")

    # Throwaway model used only to report the parameter breakdown.
    tmp_model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, args.dropout).to(dev)
    total_params, linear_params, linear_frac = parameter_fractions(tmp_model)
    del tmp_model
    print(f"params total={total_params} linear={linear_params} linear_fraction={linear_frac:.3f}")
    if args.freeze_non_linear_when_sparse:
        print("freeze_non_linear_when_sparse=True: embeddings/layernorm/etc. are frozen in sparse runs")
    else:
        print("freeze_non_linear_when_sparse=False: non-Linear params are still updated densely")

    if args.dropout != 0.0:
        print("warning: dropout is nonzero; dense audit and sparse training passes may see different dropout masks")

    rows: List[Dict[str, float | str]] = []
    print("\nRunning dense baseline")
    rows.append(
        train_run(
            corpus,
            args,
            policy=None,
            backward_mode=None,
            active_fraction=1.0,
            warmup_steps=0,
            explore_fraction=0.0,
            seed_offset=0,
        )
    )

    # Full cartesian sweep; each sparse run gets a distinct seed offset so
    # --unpaired_seeds can decorrelate initializations.
    seed_offset = 100
    for mode in args.backward_modes:
        for af in args.active_fractions:
            for pol in args.policies:
                # Exploration only matters for the predictor-driven policies.
                explore_values = args.explore_fractions if pol in {"predicted_magnitude", "ucb_magnitude"} else [0.0]
                for warmup in args.warmup_steps_list:
                    for explore in explore_values:
                        print(
                            f"\nRunning mode={mode}, policy={pol}, "
                            f"active_fraction={af:.3f}, warmup={warmup}, explore={explore:.2f}"
                        )
                        rows.append(
                            train_run(
                                corpus,
                                args,
                                policy=pol,  # type: ignore[arg-type]
                                backward_mode=mode,  # type: ignore[arg-type]
                                active_fraction=af,
                                warmup_steps=warmup,
                                explore_fraction=explore,
                                seed_offset=seed_offset,
                            )
                        )
                        seed_offset += 1

    print_summary(rows)

    print("\nNotes")
    print(" masked_optimizer is the v7-style dense-backward simulation control.")
    print(" sparse_dW_full_dX uses custom Linear backward: sparse weight/bias grads, full input gradient.")
    print(" sparse_dW_sparse_dX uses custom Linear backward: sparse weight/bias grads and sparse input gradient.")
    print(" oracle_current uses dense audit gradients to choose rows; it is an upper bound.")
    print(" predicted_magnitude uses EMA mass from active/observed rows only.")
    print(" random is the sparse-support control.")
    print(" v9 does not compute dense audit gradients after warmup unless --audit_every > 0, except oracle_current.")
    print(" predicted_magnitude updates EMA statistics only from active rows observed by the training backward pass.")
    print(" cosine/top20/jacc are nan when --audit_every 0 because no dense reference gradient is computed.")
    print(" This is still not a wall-clock benchmark: PyTorch indexing may not accelerate on CPU/MPS without a custom Metal kernel.")
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
if __name__ == "__main__":
|
| 1042 |
+
main()
|
experiments/surprise_topk_gradient_prototype-v2.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Surprise Top-K Gradient Prototype, v2
|
| 3 |
+
|
| 4 |
+
Goal
|
| 5 |
+
----
|
| 6 |
+
Test the hypothesis:
|
| 7 |
+
|
| 8 |
+
gradient_t ≈ predicted_gradient_t + sparse_surprising_residual_t
|
| 9 |
+
|
| 10 |
+
This version fixes two problems from v1:
|
| 11 |
+
1. The baseline now actually learns the toy task, using Adam instead of raw SGD.
|
| 12 |
+
2. The surprise method is tested as an approximate gradient passed into Adam,
|
| 13 |
+
rather than as a hand-written SGD update with unstable error-feedback buffers.
|
| 14 |
+
|
| 15 |
+
Important caveat
|
| 16 |
+
----------------
|
| 17 |
+
This still computes the full gradient every step. That is intentional. This is a
|
| 18 |
+
hypothesis test: does a sparse "surprising residual" preserve the useful update
|
| 19 |
+
signal? If yes, a later version can try to skip real backward computation.
|
| 20 |
+
|
| 21 |
+
Run
|
| 22 |
+
---
|
| 23 |
+
python3 surprise_topk_gradient_prototype-v2.py
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import math
|
| 29 |
+
import random
|
| 30 |
+
from dataclasses import dataclass
|
| 31 |
+
from typing import Dict, List, Tuple
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
SEED = 7
|
| 39 |
+
random.seed(SEED)
|
| 40 |
+
torch.manual_seed(SEED)
|
| 41 |
+
|
| 42 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# -----------------------------
|
| 46 |
+
# Toy data: 2-class spiral
|
| 47 |
+
# -----------------------------
|
| 48 |
+
|
| 49 |
+
def make_spiral(n_per_class: int = 1024, noise: float = 0.12, device: str | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate a shuffled 2-class interleaved spiral dataset.

    Args:
        n_per_class: number of points per class (2 classes total).
        noise: stddev of Gaussian angular jitter added to each point.
        device: target device for the returned tensors. Defaults to the
            module-level DEVICE when None, preserving the original behavior.

    Returns:
        (X, Y): X is a float tensor of shape (2 * n_per_class, 2), Y is a
        long tensor of shape (2 * n_per_class,) holding class ids 0/1, both
        shuffled with the same random permutation.
    """
    target_device = DEVICE if device is None else device
    xs = []
    ys = []

    for class_id in range(2):
        # Radius sweeps 0..1; the phase offset by pi interleaves the two arms.
        r = torch.linspace(0.0, 1.0, n_per_class)
        theta = class_id * math.pi + r * 4.0 * math.pi
        theta = theta + torch.randn(n_per_class) * noise

        x = torch.stack([r * torch.sin(theta), r * torch.cos(theta)], dim=1)
        y = torch.full((n_per_class,), class_id, dtype=torch.long)

        xs.append(x)
        ys.append(y)

    X = torch.cat(xs, dim=0)
    Y = torch.cat(ys, dim=0)

    # Mild scale expansion helps the MLP separate the spiral.
    X = 3.0 * X

    perm = torch.randperm(X.shape[0])
    return X[perm].to(target_device), Y[perm].to(target_device)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# -----------------------------
|
| 75 |
+
# Model
|
| 76 |
+
# -----------------------------
|
| 77 |
+
|
| 78 |
+
class TinyMLP(nn.Module):
    """Small fully-connected classifier: 2 inputs -> three ReLU hidden layers
    of equal width -> 2 output logits."""

    def __init__(self, width: int = 128):
        super().__init__()
        # Build the hidden stack layer-by-layer, then the final projection.
        dims = [2, width, width, width]
        modules: List[nn.Module] = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            modules.append(nn.Linear(d_in, d_out))
            modules.append(nn.ReLU())
        modules.append(nn.Linear(width, 2))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def linear_layers(model: nn.Module) -> List[nn.Linear]:
    """Collect every nn.Linear submodule of *model*, in traversal order."""
    found: List[nn.Linear] = []
    for mod in model.modules():
        if isinstance(mod, nn.Linear):
            found.append(mod)
    return found
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# -----------------------------
|
| 100 |
+
# Surprise Top-K machinery
|
| 101 |
+
# -----------------------------
|
| 102 |
+
|
| 103 |
+
@dataclass(frozen=True)
class BlockRef:
    # Identifies one "block": a single output row of one Linear layer's weight
    # matrix (plus its matching bias element). Frozen so instances are hashable.
    layer_index: int  # index into the builder's list of Linear layers
    row_index: int  # output-row index within that layer's weight matrix
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SurpriseTopKGradientBuilder:
|
| 110 |
+
"""
|
| 111 |
+
Builds approximate gradients after a full backward pass.
|
| 112 |
+
|
| 113 |
+
Block = one output row of a Linear weight matrix, plus its bias element.
|
| 114 |
+
|
| 115 |
+
On active blocks:
|
| 116 |
+
use true gradient.
|
| 117 |
+
|
| 118 |
+
On inactive blocks:
|
| 119 |
+
use predicted gradient from an exponential moving average.
|
| 120 |
+
|
| 121 |
+
Score:
|
| 122 |
+
surprise = ||true_gradient - predicted_gradient|| / (||true_gradient|| + eps)
|
| 123 |
+
|
| 124 |
+
The highest-surprise blocks become active on non-refresh steps.
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
def __init__(
    self,
    model: nn.Module,
    beta: float = 0.95,
    active_fraction: float = 0.2,
    refresh_interval: int = 10,
    warmup_steps: int = 100,
    eps: float = 1e-12,
):
    """Set up per-block predictor buffers and surprise scores for *model*."""
    self.model = model
    self.layers = linear_layers(model)
    self.beta = beta
    self.active_fraction = active_fraction
    self.refresh_interval = refresh_interval
    self.warmup_steps = warmup_steps
    self.eps = eps

    # One BlockRef per output row of every Linear layer.
    self.blocks: List[BlockRef] = [
        BlockRef(li, row)
        for li, layer in enumerate(self.layers)
        for row in range(layer.weight.shape[0])
    ]

    # EMA gradient predictors, zero-initialized, keyed by layer index.
    self.pred_w: Dict[int, torch.Tensor] = {
        li: torch.zeros_like(layer.weight.data) for li, layer in enumerate(self.layers)
    }
    self.pred_b: Dict[int, torch.Tensor] = {
        li: torch.zeros_like(layer.bias.data)
        for li, layer in enumerate(self.layers)
        if layer.bias is not None
    }

    # Surprise scores start at 1 so every block looks equally interesting.
    self.scores = torch.ones(len(self.blocks), device=DEVICE)
|
| 158 |
+
|
| 159 |
+
def _choose_active_blocks(self, step: int) -> torch.Tensor:
    """Boolean mask over blocks: all-on during warmup and refresh steps,
    otherwise only the top-k most surprising blocks."""
    n = len(self.blocks)

    # Warmup and periodic refresh steps run fully dense.
    if step < self.warmup_steps or step % self.refresh_interval == 0:
        return torch.ones(n, dtype=torch.bool, device=DEVICE)

    k = max(1, int(self.active_fraction * n))
    mask = torch.zeros(n, dtype=torch.bool, device=DEVICE)
    mask[torch.topk(self.scores, k=k).indices] = True
    return mask
|
| 173 |
+
|
| 174 |
+
    @torch.no_grad()
    def build_and_install_approx_grads(self, step: int) -> Dict[str, float]:
        """Overwrite every layer's .grad with an approximate gradient, return diagnostics.

        Active blocks keep their true gradient; inactive blocks fall back to the
        EMA predictor. Per-block surprise scores are refreshed against the
        predictor state from *before* this step's EMA update.

        Assumes loss.backward() already ran, so every layer has a populated .grad.
        """
        active = self._choose_active_blocks(step)

        true_parts = []
        approx_parts = []
        pred_parts = []

        # We build full approximate grad tensors, then overwrite .grad so Adam sees
        # the approximate gradient.
        approx_w: Dict[int, torch.Tensor] = {}
        approx_b: Dict[int, torch.Tensor] = {}
        for li, layer in enumerate(self.layers):
            approx_w[li] = torch.zeros_like(layer.weight.grad)
            if layer.bias is not None:
                approx_b[li] = torch.zeros_like(layer.bias.grad)

        for block_id, block in enumerate(self.blocks):
            li = block.layer_index
            row = block.row_index
            layer = self.layers[li]
            is_active = bool(active[block_id].item())

            # True gradient and current prediction for this block's weight row.
            g_w = layer.weight.grad[row].detach().clone()
            p_w = self.pred_w[li][row].detach().clone()

            if layer.bias is not None:
                g_b = layer.bias.grad[row].detach().clone()
                p_b = self.pred_b[li][row].detach().clone()
            else:
                g_b = None
                p_b = None

            # Score is computed against the predictor BEFORE updating predictor.
            true_vec_items = [g_w.flatten()]
            pred_vec_items = [p_w.flatten()]
            if g_b is not None:
                true_vec_items.append(g_b.view(1))
                pred_vec_items.append(p_b.view(1))

            true_vec_block = torch.cat(true_vec_items)
            pred_vec_block = torch.cat(pred_vec_items)
            residual = true_vec_block - pred_vec_block
            # Normalized surprise: relative prediction error for this block.
            self.scores[block_id] = torch.norm(residual) / (torch.norm(true_vec_block) + self.eps)

            # Active blocks pass the true gradient through; inactive ones use the prediction.
            if is_active:
                a_w = g_w
                a_b = g_b
            else:
                a_w = p_w
                a_b = p_b

            approx_w[li][row] = a_w
            if layer.bias is not None:
                approx_b[li][row] = a_b

            # Update EMA predictor from the true gradient. We allow this in the
            # prototype because full gradients are being computed for measurement.
            # A speedup version would only update this from refresh/active blocks.
            self.pred_w[li][row].mul_(self.beta).add_(g_w, alpha=1.0 - self.beta)
            if layer.bias is not None:
                self.pred_b[li][row].mul_(self.beta).add_(g_b, alpha=1.0 - self.beta)

            approx_vec_items = [a_w.flatten()]
            if a_b is not None:
                approx_vec_items.append(a_b.view(1))

            true_parts.append(true_vec_block)
            pred_parts.append(pred_vec_block)
            approx_parts.append(torch.cat(approx_vec_items))

        # Install approximate gradients for the optimizer.
        for li, layer in enumerate(self.layers):
            layer.weight.grad.copy_(approx_w[li])
            if layer.bias is not None:
                layer.bias.grad.copy_(approx_b[li])

        true_vec = torch.cat(true_parts)
        pred_vec = torch.cat(pred_parts)
        approx_vec = torch.cat(approx_parts)

        # Diagnostics: alignment of the installed gradient with the true one, and
        # how much of the true gradient the predictor alone explains.
        cosine = F.cosine_similarity(true_vec, approx_vec, dim=0).item()
        pred_explained = 1.0 - (
            torch.norm(true_vec - pred_vec).pow(2) / (torch.norm(true_vec).pow(2) + self.eps)
        ).item()

        # Fraction of total surprise held by the top 20% of blocks (tail-heaviness).
        k20 = max(1, int(0.2 * len(self.blocks)))
        sorted_scores = torch.sort(self.scores.detach(), descending=True).values
        top20_mass = (sorted_scores[:k20].sum() / (sorted_scores.sum() + self.eps)).item()

        return {
            "active_fraction": float(active.float().mean().item()),
            "cosine_true_vs_approx": cosine,
            "pred_explained_fraction": pred_explained,
            "top20_surprise_mass": top20_mass,
        }
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# -----------------------------
|
| 273 |
+
# Metrics and training
|
| 274 |
+
# -----------------------------
|
| 275 |
+
|
| 276 |
+
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of samples in (X, y) that the model classifies correctly."""
    model.eval()
    with torch.no_grad():
        predictions = model(X).argmax(dim=1)
        correct = (predictions == y).float()
    return correct.mean().item()
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def train_baseline(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train TinyMLP with dense Adam updates, logging loss/accuracy every 100 steps."""
    model = TinyMLP().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        batch_idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        loss = F.cross_entropy(model(X[batch_idx]), y[batch_idx])

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Log a checkpoint every 100 steps plus the final step.
        if step % 100 == 0 or step == steps - 1:
            record = {
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
            }
            history.append(record)

    return model, history
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def train_surprise_topk(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
    active_fraction: float = 0.2,
    refresh_interval: int = 10,
    warmup_steps: int = 100,
    beta: float = 0.95,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train TinyMLP with Adam fed surprise-top-k approximate gradients."""
    model = TinyMLP().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    builder = SurpriseTopKGradientBuilder(
        model,
        beta=beta,
        active_fraction=active_fraction,
        refresh_interval=refresh_interval,
        warmup_steps=warmup_steps,
    )
    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        batch_idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        loss = F.cross_entropy(model(X[batch_idx]), y[batch_idx])

        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Swap the true gradients for approximate ones before the Adam step.
        diagnostics = builder.build_and_install_approx_grads(step)
        optimizer.step()

        # Log a checkpoint every 100 steps plus the final step.
        if step % 100 == 0 or step == steps - 1:
            history.append({
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
                **diagnostics,
            })

    return model, history
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def print_last(label: str, history: List[Dict[str, float]]) -> None:
    """Pretty-print every field of the final history record under a heading."""
    print(f"\n{label}")
    for key, value in history[-1].items():
        formatted = f"{value:.4f}" if isinstance(value, float) else str(value)
        print(f" {key:28s}: {formatted}")
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def print_checkpoints(label: str, history: List[Dict[str, float]]) -> None:
    """Print roughly eight evenly spaced checkpoint rows from the history."""
    print(f"\n{label} checkpoints:")
    stride = max(1, len(history) // 8)
    for row in history[::stride]:
        # Sparse-run rows carry extra diagnostics; baseline rows do not.
        extra = ""
        if "cosine_true_vs_approx" in row:
            extra = (
                f" active={row['active_fraction']:.2f}"
                f" cos={row['cosine_true_vs_approx']:.3f}"
                f" pred_expl={row['pred_explained_fraction']:.3f}"
                f" top20_mass={row['top20_surprise_mass']:.3f}"
            )
        base = (
            f"step={row['step']:4d} "
            f"loss={row['loss']:.4f} "
            f"acc={row['accuracy']:.3f}"
        )
        print(base + extra)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def main() -> None:
    """Run the dense baseline and the surprise-top-k run, then print summaries."""
    X, y = make_spiral()

    baseline_model, baseline_hist = train_baseline(X, y)
    surprise_model, surprise_hist = train_surprise_topk(
        X,
        y,
        active_fraction=0.2,
        refresh_interval=10,
        warmup_steps=100,
        beta=0.95,
    )

    print_last("Baseline full Adam", baseline_hist)
    print_last("Surprise Top-K simulated Adam", surprise_hist)

    print_checkpoints("Baseline", baseline_hist)
    print_checkpoints("Surprise Top-K", surprise_hist)

    print("\nHow to read this:")
    for note in (
        " cos near 1.0 => approximate update points like the true gradient",
        " pred_expl > 0 => predictor beats zero as a gradient guess",
        " top20_mass high => surprise is heavy-tailed / concentrated",
        " accuracy close => approximation did not wreck training",
    ):
        print(note)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# Script entry point: run the full experiment when executed directly.
if __name__ == "__main__":
    main()
|
experiments/surprise_topk_gradient_prototype-v3.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Surprise Top-K Gradient Prototype, v3
|
| 3 |
+
|
| 4 |
+
Goal
|
| 5 |
+
----
|
| 6 |
+
Test the hypothesis:
|
| 7 |
+
|
| 8 |
+
gradient_t ≈ predicted_gradient_t + sparse_surprising_residual_t
|
| 9 |
+
|
| 10 |
+
What changed from v2
|
| 11 |
+
--------------------
|
| 12 |
+
1. Checkpoints no longer accidentally land mostly on refresh steps.
|
| 13 |
+
2. Surprise mass is measured using absolute residual norm, not a normalized ratio
|
| 14 |
+
that explodes when true gradients are tiny.
|
| 15 |
+
3. We track both:
|
| 16 |
+
- score_for_selection: normalized surprise, used to decide active blocks
|
| 17 |
+
- residual_mass: absolute residual norm, used to test heavy-tailed structure
|
| 18 |
+
4. We print separate summaries for refresh steps and sparse steps.
|
| 19 |
+
5. We compare surprise-top-k against magnitude-top-k and random-top-k policies.
|
| 20 |
+
|
| 21 |
+
Important caveat
|
| 22 |
+
----------------
|
| 23 |
+
This still computes the full gradient every step. That is intentional. This is a
|
| 24 |
+
hypothesis test: does a sparse "surprising residual" preserve the useful update
|
| 25 |
+
signal? If yes, a later version can try to skip real backward computation.
|
| 26 |
+
|
| 27 |
+
Run
|
| 28 |
+
---
|
| 29 |
+
python3 surprise_topk_gradient_prototype-v3.py
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import math
|
| 35 |
+
import random
|
| 36 |
+
from dataclasses import dataclass
|
| 37 |
+
from typing import Dict, List, Literal, Tuple
|
| 38 |
+
|
| 39 |
+
import torch
|
| 40 |
+
import torch.nn as nn
|
| 41 |
+
import torch.nn.functional as F
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
SEED = 7
|
| 45 |
+
random.seed(SEED)
|
| 46 |
+
torch.manual_seed(SEED)
|
| 47 |
+
|
| 48 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 49 |
+
Policy = Literal["surprise", "magnitude", "random"]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# -----------------------------
|
| 53 |
+
# Toy data: 2-class spiral
|
| 54 |
+
# -----------------------------
|
| 55 |
+
|
| 56 |
+
def make_spiral(n_per_class: int = 1024, noise: float = 0.12) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate a shuffled two-class spiral dataset, moved to DEVICE."""
    features = []
    labels = []

    for class_id in range(2):
        # Radius grows linearly; the angle winds two full turns, offset half a
        # turn per class, with Gaussian angular jitter.
        radius = torch.linspace(0.0, 1.0, n_per_class)
        angle = class_id * math.pi + radius * 4.0 * math.pi + torch.randn(n_per_class) * noise

        features.append(torch.stack([radius * torch.sin(angle), radius * torch.cos(angle)], dim=1))
        labels.append(torch.full((n_per_class,), class_id, dtype=torch.long))

    X = 3.0 * torch.cat(features, dim=0)
    Y = torch.cat(labels, dim=0)

    shuffle = torch.randperm(X.shape[0])
    return X[shuffle].to(DEVICE), Y[shuffle].to(DEVICE)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# -----------------------------
|
| 80 |
+
# Model
|
| 81 |
+
# -----------------------------
|
| 82 |
+
|
| 83 |
+
class TinyMLP(nn.Module):
    """Small fully connected ReLU classifier: 2 inputs -> 2 logits."""

    def __init__(self, width: int = 128):
        super().__init__()
        stack = [
            nn.Linear(2, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 2),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of 2-D points."""
        return self.net(x)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def linear_layers(model: nn.Module) -> List[nn.Linear]:
    """Collect every nn.Linear submodule of the model, in traversal order."""
    found: List[nn.Linear] = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            found.append(module)
    return found
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# -----------------------------
|
| 105 |
+
# Surprise Top-K machinery
|
| 106 |
+
# -----------------------------
|
| 107 |
+
|
| 108 |
+
@dataclass(frozen=True)
class BlockRef:
    """Identifies one gradient block: a single output row of a Linear layer."""
    # Index into the builder's list of Linear layers.
    layer_index: int
    # Output-row index within that layer's weight matrix (and bias vector).
    row_index: int
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class SparseGradientBuilder:
    """
    Builds approximate gradients after a full backward pass.

    Block = one output row of a Linear weight matrix, plus its bias element.

    Approx gradient:
        active blocks -> true gradient
        inactive blocks -> EMA-predicted gradient

    Selection policies:
        surprise -> choose blocks where true gradient differs most from prediction
        magnitude -> choose blocks with largest true gradient norm
        random -> choose random blocks
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy = "surprise",
        beta: float = 0.95,
        active_fraction: float = 0.2,
        refresh_interval: int = 10,
        warmup_steps: int = 100,
        eps: float = 1e-12,
    ):
        """Set up block bookkeeping, EMA predictors, and per-block statistics."""
        self.model = model
        self.layers = linear_layers(model)
        self.policy = policy
        self.beta = beta
        self.active_fraction = active_fraction
        self.refresh_interval = refresh_interval
        self.warmup_steps = warmup_steps
        self.eps = eps

        # One block per output row of every Linear layer.
        self.blocks: List[BlockRef] = []
        for li, layer in enumerate(self.layers):
            for row in range(layer.weight.shape[0]):
                self.blocks.append(BlockRef(li, row))

        # EMA gradient predictors, one tensor per layer, initialized to zero.
        self.pred_w: Dict[int, torch.Tensor] = {}
        self.pred_b: Dict[int, torch.Tensor] = {}

        for li, layer in enumerate(self.layers):
            self.pred_w[li] = torch.zeros_like(layer.weight.data)
            if layer.bias is not None:
                self.pred_b[li] = torch.zeros_like(layer.bias.data)

        # Per-block statistics, refreshed on every build call:
        #   selection_score -> normalized surprise (drives the "surprise" policy)
        #   residual_mass   -> absolute prediction-error norm
        #   gradient_mass   -> absolute true-gradient norm (drives "magnitude")
        self.selection_score = torch.ones(len(self.blocks), device=DEVICE)
        self.residual_mass = torch.ones(len(self.blocks), device=DEVICE)
        self.gradient_mass = torch.ones(len(self.blocks), device=DEVICE)

    def _is_refresh_step(self, step: int) -> bool:
        """True when every block should use its true gradient this step."""
        return step < self.warmup_steps or step % self.refresh_interval == 0

    def _choose_active_blocks(self, step: int) -> torch.Tensor:
        """Return a boolean mask over blocks; True means use the true gradient."""
        n = len(self.blocks)

        if self._is_refresh_step(step):
            return torch.ones(n, dtype=torch.bool, device=DEVICE)

        k = max(1, int(self.active_fraction * n))
        active = torch.zeros(n, dtype=torch.bool, device=DEVICE)

        if self.policy == "surprise":
            idx = torch.topk(self.selection_score, k=k).indices
        elif self.policy == "magnitude":
            idx = torch.topk(self.gradient_mass, k=k).indices
        elif self.policy == "random":
            idx = torch.randperm(n, device=DEVICE)[:k]
        else:
            raise ValueError(f"Unknown policy: {self.policy}")

        active[idx] = True
        return active

    @torch.no_grad()
    def build_and_install_approx_grads(self, step: int) -> Dict[str, float]:
        """Overwrite every layer's .grad with an approximate gradient, return diagnostics.

        Assumes loss.backward() already ran, so every layer has a populated
        .grad. Scores are measured against the predictor state from *before*
        this step's EMA update.
        """
        active = self._choose_active_blocks(step)
        is_refresh = self._is_refresh_step(step)

        true_parts = []
        approx_parts = []
        pred_parts = []

        # Full-size buffers that will replace .grad once filled in.
        approx_w: Dict[int, torch.Tensor] = {}
        approx_b: Dict[int, torch.Tensor] = {}
        for li, layer in enumerate(self.layers):
            approx_w[li] = torch.zeros_like(layer.weight.grad)
            if layer.bias is not None:
                approx_b[li] = torch.zeros_like(layer.bias.grad)

        for block_id, block in enumerate(self.blocks):
            li = block.layer_index
            row = block.row_index
            layer = self.layers[li]
            is_active = bool(active[block_id].item())

            # True gradient and current prediction for this block.
            g_w = layer.weight.grad[row].detach().clone()
            p_w = self.pred_w[li][row].detach().clone()

            if layer.bias is not None:
                g_b = layer.bias.grad[row].detach().clone()
                p_b = self.pred_b[li][row].detach().clone()
            else:
                g_b = None
                p_b = None

            true_vec_items = [g_w.flatten()]
            pred_vec_items = [p_w.flatten()]
            if g_b is not None:
                true_vec_items.append(g_b.view(1))
                pred_vec_items.append(p_b.view(1))

            true_vec_block = torch.cat(true_vec_items)
            pred_vec_block = torch.cat(pred_vec_items)
            residual = true_vec_block - pred_vec_block

            grad_norm = torch.norm(true_vec_block)
            residual_norm = torch.norm(residual)

            self.gradient_mass[block_id] = grad_norm
            self.residual_mass[block_id] = residual_norm

            # Selection score deliberately normalized, but with a floor so tiny
            # gradients do not create absurd ratios. clamp_min applies the same
            # floor as torch.maximum with a 1e-6 scalar, without allocating a
            # fresh tensor on every loop iteration.
            self.selection_score[block_id] = residual_norm / grad_norm.clamp_min(1e-6)

            if is_active:
                a_w = g_w
                a_b = g_b
            else:
                a_w = p_w
                a_b = p_b

            approx_w[li][row] = a_w
            if layer.bias is not None:
                approx_b[li][row] = a_b

            # Prototype choice: update EMA from true gradient on every step,
            # because we computed full gradients for measurement. A speedup version
            # would only update this on refresh/active blocks.
            self.pred_w[li][row].mul_(self.beta).add_(g_w, alpha=1.0 - self.beta)
            if layer.bias is not None:
                self.pred_b[li][row].mul_(self.beta).add_(g_b, alpha=1.0 - self.beta)

            approx_vec_items = [a_w.flatten()]
            if a_b is not None:
                approx_vec_items.append(a_b.view(1))

            true_parts.append(true_vec_block)
            pred_parts.append(pred_vec_block)
            approx_parts.append(torch.cat(approx_vec_items))

        # Install the approximate gradients for the optimizer.
        for li, layer in enumerate(self.layers):
            layer.weight.grad.copy_(approx_w[li])
            if layer.bias is not None:
                layer.bias.grad.copy_(approx_b[li])

        true_vec = torch.cat(true_parts)
        pred_vec = torch.cat(pred_parts)
        approx_vec = torch.cat(approx_parts)

        true_norm = torch.norm(true_vec)
        pred_error_norm = torch.norm(true_vec - pred_vec)

        # Alignment of the installed gradient with the true one, and how much
        # of the true gradient the predictor alone explains (vs a zero guess).
        cosine = F.cosine_similarity(true_vec, approx_vec, dim=0).item()
        pred_explained = 1.0 - (
            pred_error_norm.pow(2) / (true_norm.pow(2) + self.eps)
        ).item()

        # Fraction of total residual / gradient mass held by the top 20% of
        # blocks: a tail-heaviness measure for each quantity.
        k20 = max(1, int(0.2 * len(self.blocks)))

        sorted_residual = torch.sort(self.residual_mass.detach(), descending=True).values
        top20_residual_mass = (sorted_residual[:k20].sum() / (sorted_residual.sum() + self.eps)).item()

        sorted_gradient = torch.sort(self.gradient_mass.detach(), descending=True).values
        top20_gradient_mass = (sorted_gradient[:k20].sum() / (sorted_gradient.sum() + self.eps)).item()

        return {
            "is_refresh": float(is_refresh),
            "active_fraction": float(active.float().mean().item()),
            "cosine_true_vs_approx": cosine,
            "pred_explained_fraction": pred_explained,
            "top20_residual_mass": top20_residual_mass,
            "top20_gradient_mass": top20_gradient_mass,
            "true_grad_norm": float(true_norm.item()),
            "pred_error_norm": float(pred_error_norm.item()),
        }
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# -----------------------------
|
| 307 |
+
# Metrics and training
|
| 308 |
+
# -----------------------------
|
| 309 |
+
|
| 310 |
+
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of samples in (X, y) that the model classifies correctly."""
    model.eval()
    with torch.no_grad():
        hits = (model(X).argmax(dim=1) == y).float()
    return hits.mean().item()
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def train_baseline(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train TinyMLP with dense Adam updates; log checkpoints every 97 steps."""
    model = TinyMLP().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        batch_idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        loss = F.cross_entropy(model(X[batch_idx]), y[batch_idx])

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Stride of 97 (prime) keeps checkpoints off the refresh_interval grid.
        if step % 97 == 0 or step == steps - 1:
            history.append({
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
            })

    return model, history
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def train_sparse_policy(
    X: torch.Tensor,
    y: torch.Tensor,
    policy: Policy,
    steps: int = 2000,
    batch_size: int = 256,
    lr: float = 1e-3,
    active_fraction: float = 0.2,
    refresh_interval: int = 10,
    warmup_steps: int = 100,
    beta: float = 0.95,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train TinyMLP with Adam fed approximate gradients from the given policy."""
    model = TinyMLP().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    builder = SparseGradientBuilder(
        model,
        policy=policy,
        beta=beta,
        active_fraction=active_fraction,
        refresh_interval=refresh_interval,
        warmup_steps=warmup_steps,
    )

    history: List[Dict[str, float]] = []

    for step in range(steps):
        model.train()
        batch_idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        loss = F.cross_entropy(model(X[batch_idx]), y[batch_idx])

        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Swap the true gradients for approximate ones before the Adam step.
        diagnostics = builder.build_and_install_approx_grads(step)
        optimizer.step()

        # Stride of 97 (prime) keeps checkpoints off the refresh_interval grid.
        if step % 97 == 0 or step == steps - 1:
            history.append({
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
                **diagnostics,
            })

    return model, history
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def avg_sparse_metric(history: List[Dict[str, float]], key: str) -> float:
    """Average `key` over logged checkpoints taken on non-refresh (sparse) steps."""
    total = 0.0
    count = 0
    for row in history:
        if row.get("is_refresh", 0.0) == 0.0:
            total += row[key]
            count += 1
    # NaN when no sparse checkpoint was logged, matching the original contract.
    return total / count if count else float("nan")
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def print_last(label: str, history: List[Dict[str, float]]) -> None:
    """Pretty-print every field of the final history record under a heading."""
    print(f"\n{label}")
    for key, value in history[-1].items():
        formatted = f"{value:.4f}" if isinstance(value, float) else str(value)
        print(f" {key:28s}: {formatted}")
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def print_sparse_summary(label: str, history: List[Dict[str, float]]) -> None:
    """Print sparse-step metric averages plus the final accuracy and loss."""
    last = history[-1]
    print(f"\n{label} sparse-step averages from logged checkpoints")
    metric_keys = (
        "cosine_true_vs_approx",
        "pred_explained_fraction",
        "top20_residual_mass",
        "top20_gradient_mass",
    )
    for key in metric_keys:
        print(f" avg {key:24s}: {avg_sparse_metric(history, key):.4f}")
    print(f" final accuracy : {last['accuracy']:.4f}")
    print(f" final loss : {last['loss']:.4f}")
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def print_checkpoints(label: str, history: List[Dict[str, float]]) -> None:
    """Print roughly eight evenly spaced checkpoint rows from the history."""
    print(f"\n{label} checkpoints:")
    stride = max(1, len(history) // 8)
    for row in history[::stride]:
        # Sparse-run rows carry extra diagnostics; baseline rows do not.
        extra = ""
        if "cosine_true_vs_approx" in row:
            extra = (
                f" refresh={int(row['is_refresh'])}"
                f" active={row['active_fraction']:.2f}"
                f" cos={row['cosine_true_vs_approx']:.3f}"
                f" pred_expl={row['pred_explained_fraction']:.3f}"
                f" top20_resid={row['top20_residual_mass']:.3f}"
                f" top20_grad={row['top20_gradient_mass']:.3f}"
            )
        base = (
            f"step={row['step']:4d} "
            f"loss={row['loss']:.4f} "
            f"acc={row['accuracy']:.3f}"
        )
        print(base + extra)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def main() -> None:
    """Train the dense baseline plus all three sparse policies, then report."""
    X, y = make_spiral()

    baseline_model, baseline_hist = train_baseline(X, y)

    results = {}
    for policy in ["surprise", "magnitude", "random"]:
        _, hist = train_sparse_policy(
            X,
            y,
            policy=policy,
            active_fraction=0.2,
            refresh_interval=10,
            warmup_steps=100,
            beta=0.95,
        )
        results[policy] = hist

    print_last("Baseline full Adam", baseline_hist)
    for policy, hist in results.items():
        print_last(f"{policy.title()} Top-K simulated Adam", hist)

    print_checkpoints("Baseline", baseline_hist)
    for policy, hist in results.items():
        print_checkpoints(f"{policy.title()} Top-K", hist)
        print_sparse_summary(f"{policy.title()} Top-K", hist)

    print("\nHow to read this:")
    for note in (
        " refresh=0 => a real sparse/approximate logged step",
        " cos near 1.0 => approximate update points like the true gradient",
        " pred_expl > 0 => predictor beats zero as a gradient guess",
        " top20_resid high => prediction error is heavy-tailed/concentrated",
        " top20_grad high => raw gradient mass is heavy-tailed/concentrated",
        " surprise > baselines => surprise is doing more than ordinary top-k/random",
    ):
        print(note)
|
| 485 |
+
|
| 486 |
+
# Script entry point: run the full comparison when executed directly.
if __name__ == "__main__":
    main()
|
experiments/surprise_topk_gradient_prototype-v4.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Surprise / Predicted-Magnitude Top-K Gradient Prototype, v4
|
| 3 |
+
|
| 4 |
+
Goal
|
| 5 |
+
----
|
| 6 |
+
Test the practical version of the hypothesis:
|
| 7 |
+
|
| 8 |
+
The gradient/update signal is heavy-tailed, and the high-mass blocks are
|
| 9 |
+
predictable enough that we can choose where to spend backward/update compute.
|
| 10 |
+
|
| 11 |
+
What v4 adds
|
| 12 |
+
------------
|
| 13 |
+
A new policy:
|
| 14 |
+
|
| 15 |
+
predicted_magnitude
|
| 16 |
+
|
| 17 |
+
This selects active blocks using only a historical EMA of block gradient norms,
|
| 18 |
+
not the current gradient. This is much closer to something that could eventually
|
| 19 |
+
save backward computation, because the active set is known before using the
|
| 20 |
+
current full gradient.
|
| 21 |
+
|
| 22 |
+
Policies compared
|
| 23 |
+
-----------------
|
| 24 |
+
1. surprise
|
| 25 |
+
Select blocks whose current gradient is least predictable from EMA gradient.
|
| 26 |
+
This is mostly a diagnostic/oracle because it uses the current gradient.
|
| 27 |
+
|
| 28 |
+
2. magnitude
|
| 29 |
+
Select blocks with largest current gradient norm.
|
| 30 |
+
This is an oracle upper-bound for simple top-k block sparsification.
|
| 31 |
+
|
| 32 |
+
3. predicted_magnitude
|
| 33 |
+
Select blocks with largest EMA-predicted gradient norm.
|
| 34 |
+
This is the important practical test.
|
| 35 |
+
|
| 36 |
+
4. random
|
| 37 |
+
Control baseline.
|
| 38 |
+
|
| 39 |
+
Important caveat
|
| 40 |
+
----------------
|
| 41 |
+
This still computes the full gradient every step. That is intentional. We are
|
| 42 |
+
measuring whether the active-set prediction would have worked. Actual speedup
|
| 43 |
+
would require skipping or restricting backward computation for inactive blocks.
|
| 44 |
+
|
| 45 |
+
Run
|
| 46 |
+
---
|
| 47 |
+
python3 surprise_topk_gradient_prototype.py
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
from __future__ import annotations
|
| 51 |
+
|
| 52 |
+
import math
|
| 53 |
+
import random
|
| 54 |
+
from dataclasses import dataclass
|
| 55 |
+
from typing import Dict, List, Literal, Tuple
|
| 56 |
+
|
| 57 |
+
import torch
|
| 58 |
+
import torch.nn as nn
|
| 59 |
+
import torch.nn.functional as F
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
SEED = 7
|
| 63 |
+
random.seed(SEED)
|
| 64 |
+
torch.manual_seed(SEED)
|
| 65 |
+
|
| 66 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 67 |
+
Policy = Literal["surprise", "magnitude", "predicted_magnitude", "random"]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# -----------------------------
|
| 71 |
+
# Toy data: 2-class spiral
|
| 72 |
+
# -----------------------------
|
| 73 |
+
|
| 74 |
+
def make_spiral(n_per_class: int = 1024, noise: float = 0.12) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 75 |
+
xs = []
|
| 76 |
+
ys = []
|
| 77 |
+
|
| 78 |
+
for class_id in range(2):
|
| 79 |
+
r = torch.linspace(0.0, 1.0, n_per_class)
|
| 80 |
+
theta = class_id * math.pi + r * 4.0 * math.pi
|
| 81 |
+
theta = theta + torch.randn(n_per_class) * noise
|
| 82 |
+
|
| 83 |
+
x = torch.stack([r * torch.sin(theta), r * torch.cos(theta)], dim=1)
|
| 84 |
+
y = torch.full((n_per_class,), class_id, dtype=torch.long)
|
| 85 |
+
|
| 86 |
+
xs.append(x)
|
| 87 |
+
ys.append(y)
|
| 88 |
+
|
| 89 |
+
X = torch.cat(xs, dim=0)
|
| 90 |
+
Y = torch.cat(ys, dim=0)
|
| 91 |
+
X = 3.0 * X
|
| 92 |
+
|
| 93 |
+
perm = torch.randperm(X.shape[0])
|
| 94 |
+
return X[perm].to(DEVICE), Y[perm].to(DEVICE)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# -----------------------------
|
| 98 |
+
# Model
|
| 99 |
+
# -----------------------------
|
| 100 |
+
|
| 101 |
+
class TinyMLP(nn.Module):
|
| 102 |
+
def __init__(self, width: int = 128):
|
| 103 |
+
super().__init__()
|
| 104 |
+
self.net = nn.Sequential(
|
| 105 |
+
nn.Linear(2, width),
|
| 106 |
+
nn.ReLU(),
|
| 107 |
+
nn.Linear(width, width),
|
| 108 |
+
nn.ReLU(),
|
| 109 |
+
nn.Linear(width, width),
|
| 110 |
+
nn.ReLU(),
|
| 111 |
+
nn.Linear(width, 2),
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
return self.net(x)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def linear_layers(model: nn.Module) -> List[nn.Linear]:
|
| 119 |
+
return [m for m in model.modules() if isinstance(m, nn.Linear)]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# -----------------------------
|
| 123 |
+
# Sparse gradient machinery
|
| 124 |
+
# -----------------------------
|
| 125 |
+
|
| 126 |
+
@dataclass(frozen=True)
|
| 127 |
+
class BlockRef:
|
| 128 |
+
layer_index: int
|
| 129 |
+
row_index: int
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class SparseGradientBuilder:
|
| 133 |
+
"""
|
| 134 |
+
Builds approximate gradients after a full backward pass.
|
| 135 |
+
|
| 136 |
+
Block = one output row of a Linear weight matrix, plus its bias element.
|
| 137 |
+
|
| 138 |
+
Approx gradient:
|
| 139 |
+
active blocks -> true gradient
|
| 140 |
+
inactive blocks -> EMA-predicted gradient
|
| 141 |
+
|
| 142 |
+
Selection policies:
|
| 143 |
+
surprise -> largest current residual: ||g - pred_g||
|
| 144 |
+
magnitude -> largest current gradient norm: ||g||
|
| 145 |
+
predicted_magnitude -> largest historical EMA gradient norm
|
| 146 |
+
random -> random active blocks
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def __init__(
|
| 150 |
+
self,
|
| 151 |
+
model: nn.Module,
|
| 152 |
+
policy: Policy = "predicted_magnitude",
|
| 153 |
+
grad_beta: float = 0.95,
|
| 154 |
+
mass_beta: float = 0.95,
|
| 155 |
+
active_fraction: float = 0.2,
|
| 156 |
+
refresh_interval: int = 10,
|
| 157 |
+
warmup_steps: int = 100,
|
| 158 |
+
eps: float = 1e-12,
|
| 159 |
+
):
|
| 160 |
+
self.model = model
|
| 161 |
+
self.layers = linear_layers(model)
|
| 162 |
+
self.policy = policy
|
| 163 |
+
self.grad_beta = grad_beta
|
| 164 |
+
self.mass_beta = mass_beta
|
| 165 |
+
self.active_fraction = active_fraction
|
| 166 |
+
self.refresh_interval = refresh_interval
|
| 167 |
+
self.warmup_steps = warmup_steps
|
| 168 |
+
self.eps = eps
|
| 169 |
+
|
| 170 |
+
self.blocks: List[BlockRef] = []
|
| 171 |
+
for li, layer in enumerate(self.layers):
|
| 172 |
+
for row in range(layer.weight.shape[0]):
|
| 173 |
+
self.blocks.append(BlockRef(li, row))
|
| 174 |
+
|
| 175 |
+
self.pred_w: Dict[int, torch.Tensor] = {}
|
| 176 |
+
self.pred_b: Dict[int, torch.Tensor] = {}
|
| 177 |
+
|
| 178 |
+
for li, layer in enumerate(self.layers):
|
| 179 |
+
self.pred_w[li] = torch.zeros_like(layer.weight.data)
|
| 180 |
+
if layer.bias is not None:
|
| 181 |
+
self.pred_b[li] = torch.zeros_like(layer.bias.data)
|
| 182 |
+
|
| 183 |
+
n = len(self.blocks)
|
| 184 |
+
self.current_gradient_mass = torch.ones(n, device=DEVICE)
|
| 185 |
+
self.current_residual_mass = torch.ones(n, device=DEVICE)
|
| 186 |
+
self.predicted_gradient_mass = torch.ones(n, device=DEVICE)
|
| 187 |
+
self.selection_score = torch.ones(n, device=DEVICE)
|
| 188 |
+
|
| 189 |
+
# Active-set stability diagnostic.
|
| 190 |
+
self.prev_active = torch.zeros(n, dtype=torch.bool, device=DEVICE)
|
| 191 |
+
|
| 192 |
+
def _is_refresh_step(self, step: int) -> bool:
|
| 193 |
+
return step < self.warmup_steps or step % self.refresh_interval == 0
|
| 194 |
+
|
| 195 |
+
def _choose_active_blocks(self, step: int) -> torch.Tensor:
|
| 196 |
+
n = len(self.blocks)
|
| 197 |
+
|
| 198 |
+
if self._is_refresh_step(step):
|
| 199 |
+
return torch.ones(n, dtype=torch.bool, device=DEVICE)
|
| 200 |
+
|
| 201 |
+
k = max(1, int(self.active_fraction * n))
|
| 202 |
+
active = torch.zeros(n, dtype=torch.bool, device=DEVICE)
|
| 203 |
+
|
| 204 |
+
if self.policy == "surprise":
|
| 205 |
+
# Uses last step's residual score when choosing before the current
|
| 206 |
+
# gradient is processed. After full gradient is computed below, this
|
| 207 |
+
# becomes an oracle-ish diagnostic of residual concentration.
|
| 208 |
+
idx = torch.topk(self.current_residual_mass, k=k).indices
|
| 209 |
+
elif self.policy == "magnitude":
|
| 210 |
+
# Uses last observed current_gradient_mass from previous step, not the
|
| 211 |
+
# current one, at selection time. Still a strong baseline.
|
| 212 |
+
idx = torch.topk(self.current_gradient_mass, k=k).indices
|
| 213 |
+
elif self.policy == "predicted_magnitude":
|
| 214 |
+
# This is the important practical policy: choose from historical EMA.
|
| 215 |
+
idx = torch.topk(self.predicted_gradient_mass, k=k).indices
|
| 216 |
+
elif self.policy == "random":
|
| 217 |
+
idx = torch.randperm(n, device=DEVICE)[:k]
|
| 218 |
+
else:
|
| 219 |
+
raise ValueError(f"Unknown policy: {self.policy}")
|
| 220 |
+
|
| 221 |
+
active[idx] = True
|
| 222 |
+
return active
|
| 223 |
+
|
| 224 |
+
@staticmethod
|
| 225 |
+
def _topk_mask(values: torch.Tensor, fraction: float) -> torch.Tensor:
|
| 226 |
+
n = values.numel()
|
| 227 |
+
k = max(1, int(fraction * n))
|
| 228 |
+
mask = torch.zeros(n, dtype=torch.bool, device=values.device)
|
| 229 |
+
mask[torch.topk(values, k=k).indices] = True
|
| 230 |
+
return mask
|
| 231 |
+
|
| 232 |
+
@staticmethod
|
| 233 |
+
def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
|
| 234 |
+
inter = (a & b).sum().float()
|
| 235 |
+
union = (a | b).sum().float()
|
| 236 |
+
return float((inter / torch.clamp(union, min=1.0)).item())
|
| 237 |
+
|
| 238 |
+
@torch.no_grad()
|
| 239 |
+
def build_and_install_approx_grads(self, step: int) -> Dict[str, float]:
|
| 240 |
+
active = self._choose_active_blocks(step)
|
| 241 |
+
is_refresh = self._is_refresh_step(step)
|
| 242 |
+
|
| 243 |
+
true_parts = []
|
| 244 |
+
approx_parts = []
|
| 245 |
+
pred_parts = []
|
| 246 |
+
|
| 247 |
+
approx_w: Dict[int, torch.Tensor] = {}
|
| 248 |
+
approx_b: Dict[int, torch.Tensor] = {}
|
| 249 |
+
for li, layer in enumerate(self.layers):
|
| 250 |
+
approx_w[li] = torch.zeros_like(layer.weight.grad)
|
| 251 |
+
if layer.bias is not None:
|
| 252 |
+
approx_b[li] = torch.zeros_like(layer.bias.grad)
|
| 253 |
+
|
| 254 |
+
# Temporary arrays for current-step measurements.
|
| 255 |
+
new_gradient_mass = torch.zeros_like(self.current_gradient_mass)
|
| 256 |
+
new_residual_mass = torch.zeros_like(self.current_residual_mass)
|
| 257 |
+
|
| 258 |
+
for block_id, block in enumerate(self.blocks):
|
| 259 |
+
li = block.layer_index
|
| 260 |
+
row = block.row_index
|
| 261 |
+
layer = self.layers[li]
|
| 262 |
+
is_active = bool(active[block_id].item())
|
| 263 |
+
|
| 264 |
+
g_w = layer.weight.grad[row].detach().clone()
|
| 265 |
+
p_w = self.pred_w[li][row].detach().clone()
|
| 266 |
+
|
| 267 |
+
if layer.bias is not None:
|
| 268 |
+
g_b = layer.bias.grad[row].detach().clone()
|
| 269 |
+
p_b = self.pred_b[li][row].detach().clone()
|
| 270 |
+
else:
|
| 271 |
+
g_b = None
|
| 272 |
+
p_b = None
|
| 273 |
+
|
| 274 |
+
true_vec_items = [g_w.flatten()]
|
| 275 |
+
pred_vec_items = [p_w.flatten()]
|
| 276 |
+
if g_b is not None:
|
| 277 |
+
true_vec_items.append(g_b.view(1))
|
| 278 |
+
pred_vec_items.append(p_b.view(1))
|
| 279 |
+
|
| 280 |
+
true_vec_block = torch.cat(true_vec_items)
|
| 281 |
+
pred_vec_block = torch.cat(pred_vec_items)
|
| 282 |
+
residual_vec = true_vec_block - pred_vec_block
|
| 283 |
+
|
| 284 |
+
grad_norm = torch.norm(true_vec_block)
|
| 285 |
+
residual_norm = torch.norm(residual_vec)
|
| 286 |
+
|
| 287 |
+
new_gradient_mass[block_id] = grad_norm
|
| 288 |
+
new_residual_mass[block_id] = residual_norm
|
| 289 |
+
|
| 290 |
+
if is_active:
|
| 291 |
+
a_w = g_w
|
| 292 |
+
a_b = g_b
|
| 293 |
+
else:
|
| 294 |
+
a_w = p_w
|
| 295 |
+
a_b = p_b
|
| 296 |
+
|
| 297 |
+
approx_w[li][row] = a_w
|
| 298 |
+
if layer.bias is not None:
|
| 299 |
+
approx_b[li][row] = a_b
|
| 300 |
+
|
| 301 |
+
# Prototype choice: update predictors from true gradient because the
|
| 302 |
+
# full gradient is available for measurement. A speedup version would
|
| 303 |
+
# update fully only on refresh steps and partially on active blocks.
|
| 304 |
+
self.pred_w[li][row].mul_(self.grad_beta).add_(g_w, alpha=1.0 - self.grad_beta)
|
| 305 |
+
if layer.bias is not None:
|
| 306 |
+
self.pred_b[li][row].mul_(self.grad_beta).add_(g_b, alpha=1.0 - self.grad_beta)
|
| 307 |
+
|
| 308 |
+
approx_vec_items = [a_w.flatten()]
|
| 309 |
+
if a_b is not None:
|
| 310 |
+
approx_vec_items.append(a_b.view(1))
|
| 311 |
+
|
| 312 |
+
true_parts.append(true_vec_block)
|
| 313 |
+
pred_parts.append(pred_vec_block)
|
| 314 |
+
approx_parts.append(torch.cat(approx_vec_items))
|
| 315 |
+
|
| 316 |
+
# Install approximate gradients for Adam.
|
| 317 |
+
for li, layer in enumerate(self.layers):
|
| 318 |
+
layer.weight.grad.copy_(approx_w[li])
|
| 319 |
+
if layer.bias is not None:
|
| 320 |
+
layer.bias.grad.copy_(approx_b[li])
|
| 321 |
+
|
| 322 |
+
true_vec = torch.cat(true_parts)
|
| 323 |
+
pred_vec = torch.cat(pred_parts)
|
| 324 |
+
approx_vec = torch.cat(approx_parts)
|
| 325 |
+
|
| 326 |
+
true_norm = torch.norm(true_vec)
|
| 327 |
+
pred_error_norm = torch.norm(true_vec - pred_vec)
|
| 328 |
+
|
| 329 |
+
cosine = F.cosine_similarity(true_vec, approx_vec, dim=0).item()
|
| 330 |
+
pred_explained = 1.0 - (
|
| 331 |
+
pred_error_norm.pow(2) / (true_norm.pow(2) + self.eps)
|
| 332 |
+
).item()
|
| 333 |
+
|
| 334 |
+
# Oracle masks for diagnostics after seeing the true current gradient.
|
| 335 |
+
oracle_magnitude_mask = self._topk_mask(new_gradient_mass, self.active_fraction)
|
| 336 |
+
oracle_residual_mask = self._topk_mask(new_residual_mass, self.active_fraction)
|
| 337 |
+
predicted_magnitude_mask = self._topk_mask(self.predicted_gradient_mass, self.active_fraction)
|
| 338 |
+
|
| 339 |
+
active_vs_oracle_mag = self._jaccard(active, oracle_magnitude_mask)
|
| 340 |
+
active_vs_oracle_resid = self._jaccard(active, oracle_residual_mask)
|
| 341 |
+
predmag_vs_oracle_mag = self._jaccard(predicted_magnitude_mask, oracle_magnitude_mask)
|
| 342 |
+
active_stability = self._jaccard(active, self.prev_active)
|
| 343 |
+
|
| 344 |
+
self.prev_active = active.clone()
|
| 345 |
+
|
| 346 |
+
k20 = max(1, int(0.2 * len(self.blocks)))
|
| 347 |
+
|
| 348 |
+
sorted_residual = torch.sort(new_residual_mass.detach(), descending=True).values
|
| 349 |
+
top20_residual_mass = (sorted_residual[:k20].sum() / (sorted_residual.sum() + self.eps)).item()
|
| 350 |
+
|
| 351 |
+
sorted_gradient = torch.sort(new_gradient_mass.detach(), descending=True).values
|
| 352 |
+
top20_gradient_mass = (sorted_gradient[:k20].sum() / (sorted_gradient.sum() + self.eps)).item()
|
| 353 |
+
|
| 354 |
+
# Update mass trackers AFTER diagnostics, so predicted_magnitude really
|
| 355 |
+
# uses only history at selection time.
|
| 356 |
+
self.current_gradient_mass = new_gradient_mass
|
| 357 |
+
self.current_residual_mass = new_residual_mass
|
| 358 |
+
self.predicted_gradient_mass.mul_(self.mass_beta).add_(new_gradient_mass, alpha=1.0 - self.mass_beta)
|
| 359 |
+
|
| 360 |
+
return {
|
| 361 |
+
"is_refresh": float(is_refresh),
|
| 362 |
+
"active_fraction": float(active.float().mean().item()),
|
| 363 |
+
"cosine_true_vs_approx": cosine,
|
| 364 |
+
"pred_explained_fraction": pred_explained,
|
| 365 |
+
"top20_residual_mass": top20_residual_mass,
|
| 366 |
+
"top20_gradient_mass": top20_gradient_mass,
|
| 367 |
+
"active_vs_oracle_mag": active_vs_oracle_mag,
|
| 368 |
+
"active_vs_oracle_resid": active_vs_oracle_resid,
|
| 369 |
+
"predmag_vs_oracle_mag": predmag_vs_oracle_mag,
|
| 370 |
+
"active_stability": active_stability,
|
| 371 |
+
"true_grad_norm": float(true_norm.item()),
|
| 372 |
+
"pred_error_norm": float(pred_error_norm.item()),
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# -----------------------------
|
| 377 |
+
# Metrics and training
|
| 378 |
+
# -----------------------------
|
| 379 |
+
|
| 380 |
+
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
|
| 381 |
+
model.eval()
|
| 382 |
+
with torch.no_grad():
|
| 383 |
+
pred = model(X).argmax(dim=1)
|
| 384 |
+
return (pred == y).float().mean().item()
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def train_baseline(
|
| 388 |
+
X: torch.Tensor,
|
| 389 |
+
y: torch.Tensor,
|
| 390 |
+
steps: int = 2000,
|
| 391 |
+
batch_size: int = 256,
|
| 392 |
+
lr: float = 1e-3,
|
| 393 |
+
) -> Tuple[nn.Module, List[Dict[str, float]]]:
|
| 394 |
+
model = TinyMLP().to(DEVICE)
|
| 395 |
+
opt = torch.optim.Adam(model.parameters(), lr=lr)
|
| 396 |
+
history: List[Dict[str, float]] = []
|
| 397 |
+
|
| 398 |
+
for step in range(steps):
|
| 399 |
+
model.train()
|
| 400 |
+
idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
|
| 401 |
+
xb, yb = X[idx], y[idx]
|
| 402 |
+
|
| 403 |
+
loss = F.cross_entropy(model(xb), yb)
|
| 404 |
+
|
| 405 |
+
opt.zero_grad(set_to_none=True)
|
| 406 |
+
loss.backward()
|
| 407 |
+
opt.step()
|
| 408 |
+
|
| 409 |
+
if step % 97 == 0 or step == steps - 1:
|
| 410 |
+
history.append({
|
| 411 |
+
"step": step,
|
| 412 |
+
"loss": float(loss.item()),
|
| 413 |
+
"accuracy": accuracy(model, X, y),
|
| 414 |
+
})
|
| 415 |
+
|
| 416 |
+
return model, history
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def train_sparse_policy(
|
| 420 |
+
X: torch.Tensor,
|
| 421 |
+
y: torch.Tensor,
|
| 422 |
+
policy: Policy,
|
| 423 |
+
steps: int = 2000,
|
| 424 |
+
batch_size: int = 256,
|
| 425 |
+
lr: float = 1e-3,
|
| 426 |
+
active_fraction: float = 0.2,
|
| 427 |
+
refresh_interval: int = 10,
|
| 428 |
+
warmup_steps: int = 100,
|
| 429 |
+
grad_beta: float = 0.95,
|
| 430 |
+
mass_beta: float = 0.95,
|
| 431 |
+
) -> Tuple[nn.Module, List[Dict[str, float]]]:
|
| 432 |
+
model = TinyMLP().to(DEVICE)
|
| 433 |
+
opt = torch.optim.Adam(model.parameters(), lr=lr)
|
| 434 |
+
builder = SparseGradientBuilder(
|
| 435 |
+
model,
|
| 436 |
+
policy=policy,
|
| 437 |
+
grad_beta=grad_beta,
|
| 438 |
+
mass_beta=mass_beta,
|
| 439 |
+
active_fraction=active_fraction,
|
| 440 |
+
refresh_interval=refresh_interval,
|
| 441 |
+
warmup_steps=warmup_steps,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
history: List[Dict[str, float]] = []
|
| 445 |
+
|
| 446 |
+
for step in range(steps):
|
| 447 |
+
model.train()
|
| 448 |
+
idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
|
| 449 |
+
xb, yb = X[idx], y[idx]
|
| 450 |
+
|
| 451 |
+
loss = F.cross_entropy(model(xb), yb)
|
| 452 |
+
|
| 453 |
+
opt.zero_grad(set_to_none=True)
|
| 454 |
+
loss.backward()
|
| 455 |
+
|
| 456 |
+
diagnostics = builder.build_and_install_approx_grads(step)
|
| 457 |
+
opt.step()
|
| 458 |
+
|
| 459 |
+
if step % 97 == 0 or step == steps - 1:
|
| 460 |
+
history.append({
|
| 461 |
+
"step": step,
|
| 462 |
+
"loss": float(loss.item()),
|
| 463 |
+
"accuracy": accuracy(model, X, y),
|
| 464 |
+
**diagnostics,
|
| 465 |
+
})
|
| 466 |
+
|
| 467 |
+
return model, history
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def sparse_rows(history: List[Dict[str, float]]) -> List[Dict[str, float]]:
|
| 471 |
+
return [row for row in history if row.get("is_refresh", 0.0) == 0.0]
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def avg_sparse_metric(history: List[Dict[str, float]], key: str) -> float:
|
| 475 |
+
rows = sparse_rows(history)
|
| 476 |
+
vals = [row[key] for row in rows]
|
| 477 |
+
if not vals:
|
| 478 |
+
return float("nan")
|
| 479 |
+
return sum(vals) / len(vals)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def print_last(label: str, history: List[Dict[str, float]]) -> None:
|
| 483 |
+
print(f"\n{label}")
|
| 484 |
+
for k, v in history[-1].items():
|
| 485 |
+
if isinstance(v, float):
|
| 486 |
+
print(f" {k:28s}: {v:.4f}")
|
| 487 |
+
else:
|
| 488 |
+
print(f" {k:28s}: {v}")
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def print_sparse_summary(label: str, history: List[Dict[str, float]]) -> None:
|
| 492 |
+
last = history[-1]
|
| 493 |
+
print(f"\n{label} sparse-step averages from logged checkpoints")
|
| 494 |
+
for key in [
|
| 495 |
+
"cosine_true_vs_approx",
|
| 496 |
+
"pred_explained_fraction",
|
| 497 |
+
"top20_residual_mass",
|
| 498 |
+
"top20_gradient_mass",
|
| 499 |
+
"active_vs_oracle_mag",
|
| 500 |
+
"active_vs_oracle_resid",
|
| 501 |
+
"predmag_vs_oracle_mag",
|
| 502 |
+
"active_stability",
|
| 503 |
+
]:
|
| 504 |
+
print(f" avg {key:24s}: {avg_sparse_metric(history, key):.4f}")
|
| 505 |
+
print(f" final accuracy : {last['accuracy']:.4f}")
|
| 506 |
+
print(f" final loss : {last['loss']:.4f}")
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def print_checkpoints(label: str, history: List[Dict[str, float]]) -> None:
|
| 510 |
+
print(f"\n{label} checkpoints:")
|
| 511 |
+
stride = max(1, len(history) // 8)
|
| 512 |
+
for row in history[::stride]:
|
| 513 |
+
extra = ""
|
| 514 |
+
if "cosine_true_vs_approx" in row:
|
| 515 |
+
extra = (
|
| 516 |
+
f" refresh={int(row['is_refresh'])}"
|
| 517 |
+
f" active={row['active_fraction']:.2f}"
|
| 518 |
+
f" cos={row['cosine_true_vs_approx']:.3f}"
|
| 519 |
+
f" pred_expl={row['pred_explained_fraction']:.3f}"
|
| 520 |
+
f" top20_grad={row['top20_gradient_mass']:.3f}"
|
| 521 |
+
f" jacc_mag={row['active_vs_oracle_mag']:.3f}"
|
| 522 |
+
f" stable={row['active_stability']:.3f}"
|
| 523 |
+
)
|
| 524 |
+
print(
|
| 525 |
+
f"step={row['step']:4d} "
|
| 526 |
+
f"loss={row['loss']:.4f} "
|
| 527 |
+
f"acc={row['accuracy']:.3f}"
|
| 528 |
+
f"{extra}"
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def main() -> None:
|
| 533 |
+
X, y = make_spiral()
|
| 534 |
+
|
| 535 |
+
baseline_model, baseline_hist = train_baseline(X, y)
|
| 536 |
+
|
| 537 |
+
results = {}
|
| 538 |
+
for policy in ["surprise", "magnitude", "predicted_magnitude", "random"]:
|
| 539 |
+
_, hist = train_sparse_policy(
|
| 540 |
+
X,
|
| 541 |
+
y,
|
| 542 |
+
policy=policy,
|
| 543 |
+
active_fraction=0.2,
|
| 544 |
+
refresh_interval=10,
|
| 545 |
+
warmup_steps=100,
|
| 546 |
+
grad_beta=0.95,
|
| 547 |
+
mass_beta=0.95,
|
| 548 |
+
)
|
| 549 |
+
results[policy] = hist
|
| 550 |
+
|
| 551 |
+
print_last("Baseline full Adam", baseline_hist)
|
| 552 |
+
for policy, hist in results.items():
|
| 553 |
+
print_last(f"{policy.title().replace('_', ' ')} Top-K simulated Adam", hist)
|
| 554 |
+
|
| 555 |
+
print_checkpoints("Baseline", baseline_hist)
|
| 556 |
+
for policy, hist in results.items():
|
| 557 |
+
print_checkpoints(f"{policy.title().replace('_', ' ')} Top-K", hist)
|
| 558 |
+
print_sparse_summary(f"{policy.title().replace('_', ' ')} Top-K", hist)
|
| 559 |
+
|
| 560 |
+
print("\nHow to read this:")
|
| 561 |
+
print(" predicted_magnitude => the practical policy to watch")
|
| 562 |
+
print(" magnitude => oracle-ish upper bound using recent/current mass")
|
| 563 |
+
print(" cos near 1.0 => approximate update points like the true gradient")
|
| 564 |
+
print(" top20_grad high => raw gradient mass is heavy-tailed/concentrated")
|
| 565 |
+
print(" jacc_mag high => selected blocks match current oracle top-k blocks")
|
| 566 |
+
print(" stable high => active set is stable over time")
|
| 567 |
+
print(" pred_mag close to mag => historical mass is enough to select useful blocks")
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
if __name__ == "__main__":
|
| 571 |
+
main()
|
experiments/surprise_topk_gradient_prototype-v5.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Predicted-Magnitude Sparse Gradient Prototype, v5
|
| 3 |
+
|
| 4 |
+
Goal
|
| 5 |
+
----
|
| 6 |
+
Test the practical version of the hypothesis more harshly:
|
| 7 |
+
|
| 8 |
+
The gradient/update signal is heavy-tailed, and the high-mass blocks are
|
| 9 |
+
predictable enough that we can update only those blocks most of the time.
|
| 10 |
+
|
| 11 |
+
v5 adds
|
| 12 |
+
-------
|
| 13 |
+
1. inactive_mode="ema"
|
| 14 |
+
Inactive blocks receive the EMA-predicted gradient.
|
| 15 |
+
|
| 16 |
+
2. inactive_mode="zero"
|
| 17 |
+
Inactive blocks receive zero gradient. This is the stricter and more
|
| 18 |
+
compute-relevant test.
|
| 19 |
+
|
| 20 |
+
3. active_fraction sweep
|
| 21 |
+
Tests 20%, 10%, 5%, and 2% active blocks.
|
| 22 |
+
|
| 23 |
+
4. focused policy comparison
|
| 24 |
+
- predicted_magnitude: practical policy, chooses active blocks from history
|
| 25 |
+
- magnitude: oracle-ish policy, chooses from recently observed mass
|
| 26 |
+
- random: control
|
| 27 |
+
|
| 28 |
+
Important caveat
|
| 29 |
+
----------------
|
| 30 |
+
This still computes the full gradient every step. That is intentional. We are
|
| 31 |
+
measuring whether the selected active set would have preserved useful learning.
|
| 32 |
+
Actual speedup would require restricting/skipping backward computation for
|
| 33 |
+
inactive blocks with custom structured backward logic.
|
| 34 |
+
|
| 35 |
+
Run
|
| 36 |
+
---
|
| 37 |
+
python3 surprise_topk_gradient_prototype.py
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
from __future__ import annotations
|
| 41 |
+
|
| 42 |
+
import math
|
| 43 |
+
import random
|
| 44 |
+
from dataclasses import dataclass
|
| 45 |
+
from typing import Dict, List, Literal, Tuple
|
| 46 |
+
|
| 47 |
+
import torch
|
| 48 |
+
import torch.nn as nn
|
| 49 |
+
import torch.nn.functional as F
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
SEED = 7
|
| 53 |
+
random.seed(SEED)
|
| 54 |
+
torch.manual_seed(SEED)
|
| 55 |
+
|
| 56 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 57 |
+
Policy = Literal["magnitude", "predicted_magnitude", "random"]
|
| 58 |
+
InactiveMode = Literal["ema", "zero"]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# -----------------------------
|
| 62 |
+
# Toy data: 2-class spiral
|
| 63 |
+
# -----------------------------
|
| 64 |
+
|
| 65 |
+
def make_spiral(n_per_class: int = 1024, noise: float = 0.12) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 66 |
+
xs = []
|
| 67 |
+
ys = []
|
| 68 |
+
|
| 69 |
+
for class_id in range(2):
|
| 70 |
+
r = torch.linspace(0.0, 1.0, n_per_class)
|
| 71 |
+
theta = class_id * math.pi + r * 4.0 * math.pi
|
| 72 |
+
theta = theta + torch.randn(n_per_class) * noise
|
| 73 |
+
|
| 74 |
+
x = torch.stack([r * torch.sin(theta), r * torch.cos(theta)], dim=1)
|
| 75 |
+
y = torch.full((n_per_class,), class_id, dtype=torch.long)
|
| 76 |
+
|
| 77 |
+
xs.append(x)
|
| 78 |
+
ys.append(y)
|
| 79 |
+
|
| 80 |
+
X = torch.cat(xs, dim=0)
|
| 81 |
+
Y = torch.cat(ys, dim=0)
|
| 82 |
+
X = 3.0 * X
|
| 83 |
+
|
| 84 |
+
perm = torch.randperm(X.shape[0])
|
| 85 |
+
return X[perm].to(DEVICE), Y[perm].to(DEVICE)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# -----------------------------
|
| 89 |
+
# Model
|
| 90 |
+
# -----------------------------
|
| 91 |
+
|
| 92 |
+
class TinyMLP(nn.Module):
|
| 93 |
+
def __init__(self, width: int = 128):
|
| 94 |
+
super().__init__()
|
| 95 |
+
self.net = nn.Sequential(
|
| 96 |
+
nn.Linear(2, width),
|
| 97 |
+
nn.ReLU(),
|
| 98 |
+
nn.Linear(width, width),
|
| 99 |
+
nn.ReLU(),
|
| 100 |
+
nn.Linear(width, width),
|
| 101 |
+
nn.ReLU(),
|
| 102 |
+
nn.Linear(width, 2),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 106 |
+
return self.net(x)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def linear_layers(model: nn.Module) -> List[nn.Linear]:
|
| 110 |
+
return [m for m in model.modules() if isinstance(m, nn.Linear)]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# -----------------------------
|
| 114 |
+
# Sparse gradient machinery
|
| 115 |
+
# -----------------------------
|
| 116 |
+
|
| 117 |
+
@dataclass(frozen=True)
|
| 118 |
+
class BlockRef:
|
| 119 |
+
layer_index: int
|
| 120 |
+
row_index: int
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class SparseGradientBuilder:
|
| 124 |
+
"""
|
| 125 |
+
Builds approximate gradients after a full backward pass.
|
| 126 |
+
|
| 127 |
+
Block = one output row of a Linear weight matrix, plus its bias element.
|
| 128 |
+
|
| 129 |
+
Approx gradient:
|
| 130 |
+
active blocks -> true gradient
|
| 131 |
+
inactive blocks -> either EMA-predicted gradient or zero gradient
|
| 132 |
+
|
| 133 |
+
Selection policies:
|
| 134 |
+
magnitude -> largest recently observed gradient norm
|
| 135 |
+
predicted_magnitude -> largest historical EMA gradient norm
|
| 136 |
+
random -> random active blocks
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
model: nn.Module,
|
| 142 |
+
policy: Policy = "predicted_magnitude",
|
| 143 |
+
inactive_mode: InactiveMode = "zero",
|
| 144 |
+
grad_beta: float = 0.95,
|
| 145 |
+
mass_beta: float = 0.95,
|
| 146 |
+
active_fraction: float = 0.2,
|
| 147 |
+
refresh_interval: int = 10,
|
| 148 |
+
warmup_steps: int = 100,
|
| 149 |
+
eps: float = 1e-12,
|
| 150 |
+
):
|
| 151 |
+
self.model = model
|
| 152 |
+
self.layers = linear_layers(model)
|
| 153 |
+
self.policy = policy
|
| 154 |
+
self.inactive_mode = inactive_mode
|
| 155 |
+
self.grad_beta = grad_beta
|
| 156 |
+
self.mass_beta = mass_beta
|
| 157 |
+
self.active_fraction = active_fraction
|
| 158 |
+
self.refresh_interval = refresh_interval
|
| 159 |
+
self.warmup_steps = warmup_steps
|
| 160 |
+
self.eps = eps
|
| 161 |
+
|
| 162 |
+
self.blocks: List[BlockRef] = []
|
| 163 |
+
for li, layer in enumerate(self.layers):
|
| 164 |
+
for row in range(layer.weight.shape[0]):
|
| 165 |
+
self.blocks.append(BlockRef(li, row))
|
| 166 |
+
|
| 167 |
+
self.pred_w: Dict[int, torch.Tensor] = {}
|
| 168 |
+
self.pred_b: Dict[int, torch.Tensor] = {}
|
| 169 |
+
|
| 170 |
+
for li, layer in enumerate(self.layers):
|
| 171 |
+
self.pred_w[li] = torch.zeros_like(layer.weight.data)
|
| 172 |
+
if layer.bias is not None:
|
| 173 |
+
self.pred_b[li] = torch.zeros_like(layer.bias.data)
|
| 174 |
+
|
| 175 |
+
n = len(self.blocks)
|
| 176 |
+
self.current_gradient_mass = torch.ones(n, device=DEVICE)
|
| 177 |
+
self.predicted_gradient_mass = torch.ones(n, device=DEVICE)
|
| 178 |
+
self.prev_active = torch.zeros(n, dtype=torch.bool, device=DEVICE)
|
| 179 |
+
|
| 180 |
+
def _is_refresh_step(self, step: int) -> bool:
|
| 181 |
+
return step < self.warmup_steps or step % self.refresh_interval == 0
|
| 182 |
+
|
| 183 |
+
def _choose_active_blocks(self, step: int) -> torch.Tensor:
|
| 184 |
+
n = len(self.blocks)
|
| 185 |
+
|
| 186 |
+
if self._is_refresh_step(step):
|
| 187 |
+
return torch.ones(n, dtype=torch.bool, device=DEVICE)
|
| 188 |
+
|
| 189 |
+
k = max(1, int(self.active_fraction * n))
|
| 190 |
+
active = torch.zeros(n, dtype=torch.bool, device=DEVICE)
|
| 191 |
+
|
| 192 |
+
if self.policy == "magnitude":
|
| 193 |
+
# Uses the previous step's observed gradient mass, not the current one.
|
| 194 |
+
idx = torch.topk(self.current_gradient_mass, k=k).indices
|
| 195 |
+
elif self.policy == "predicted_magnitude":
|
| 196 |
+
# Practical policy: uses historical EMA only.
|
| 197 |
+
idx = torch.topk(self.predicted_gradient_mass, k=k).indices
|
| 198 |
+
elif self.policy == "random":
|
| 199 |
+
idx = torch.randperm(n, device=DEVICE)[:k]
|
| 200 |
+
else:
|
| 201 |
+
raise ValueError(f"Unknown policy: {self.policy}")
|
| 202 |
+
|
| 203 |
+
active[idx] = True
|
| 204 |
+
return active
|
| 205 |
+
|
| 206 |
+
@staticmethod
|
| 207 |
+
def _topk_mask(values: torch.Tensor, fraction: float) -> torch.Tensor:
|
| 208 |
+
n = values.numel()
|
| 209 |
+
k = max(1, int(fraction * n))
|
| 210 |
+
mask = torch.zeros(n, dtype=torch.bool, device=values.device)
|
| 211 |
+
mask[torch.topk(values, k=k).indices] = True
|
| 212 |
+
return mask
|
| 213 |
+
|
| 214 |
+
@staticmethod
|
| 215 |
+
def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
|
| 216 |
+
inter = (a & b).sum().float()
|
| 217 |
+
union = (a | b).sum().float()
|
| 218 |
+
return float((inter / torch.clamp(union, min=1.0)).item())
|
| 219 |
+
|
| 220 |
+
@torch.no_grad()
|
| 221 |
+
def build_and_install_approx_grads(self, step: int) -> Dict[str, float]:
|
| 222 |
+
active = self._choose_active_blocks(step)
|
| 223 |
+
is_refresh = self._is_refresh_step(step)
|
| 224 |
+
|
| 225 |
+
true_parts = []
|
| 226 |
+
approx_parts = []
|
| 227 |
+
pred_parts = []
|
| 228 |
+
|
| 229 |
+
approx_w: Dict[int, torch.Tensor] = {}
|
| 230 |
+
approx_b: Dict[int, torch.Tensor] = {}
|
| 231 |
+
for li, layer in enumerate(self.layers):
|
| 232 |
+
approx_w[li] = torch.zeros_like(layer.weight.grad)
|
| 233 |
+
if layer.bias is not None:
|
| 234 |
+
approx_b[li] = torch.zeros_like(layer.bias.grad)
|
| 235 |
+
|
| 236 |
+
new_gradient_mass = torch.zeros_like(self.current_gradient_mass)
|
| 237 |
+
|
| 238 |
+
for block_id, block in enumerate(self.blocks):
|
| 239 |
+
li = block.layer_index
|
| 240 |
+
row = block.row_index
|
| 241 |
+
layer = self.layers[li]
|
| 242 |
+
is_active = bool(active[block_id].item())
|
| 243 |
+
|
| 244 |
+
g_w = layer.weight.grad[row].detach().clone()
|
| 245 |
+
p_w = self.pred_w[li][row].detach().clone()
|
| 246 |
+
|
| 247 |
+
if layer.bias is not None:
|
| 248 |
+
g_b = layer.bias.grad[row].detach().clone()
|
| 249 |
+
p_b = self.pred_b[li][row].detach().clone()
|
| 250 |
+
else:
|
| 251 |
+
g_b = None
|
| 252 |
+
p_b = None
|
| 253 |
+
|
| 254 |
+
true_vec_items = [g_w.flatten()]
|
| 255 |
+
pred_vec_items = [p_w.flatten()]
|
| 256 |
+
if g_b is not None:
|
| 257 |
+
true_vec_items.append(g_b.view(1))
|
| 258 |
+
pred_vec_items.append(p_b.view(1))
|
| 259 |
+
|
| 260 |
+
true_vec_block = torch.cat(true_vec_items)
|
| 261 |
+
pred_vec_block = torch.cat(pred_vec_items)
|
| 262 |
+
grad_norm = torch.norm(true_vec_block)
|
| 263 |
+
new_gradient_mass[block_id] = grad_norm
|
| 264 |
+
|
| 265 |
+
if is_active:
|
| 266 |
+
a_w = g_w
|
| 267 |
+
a_b = g_b
|
| 268 |
+
else:
|
| 269 |
+
if self.inactive_mode == "ema":
|
| 270 |
+
a_w = p_w
|
| 271 |
+
a_b = p_b
|
| 272 |
+
elif self.inactive_mode == "zero":
|
| 273 |
+
a_w = torch.zeros_like(g_w)
|
| 274 |
+
a_b = torch.zeros_like(g_b) if g_b is not None else None
|
| 275 |
+
else:
|
| 276 |
+
raise ValueError(f"Unknown inactive_mode: {self.inactive_mode}")
|
| 277 |
+
|
| 278 |
+
approx_w[li][row] = a_w
|
| 279 |
+
if layer.bias is not None:
|
| 280 |
+
approx_b[li][row] = a_b
|
| 281 |
+
|
| 282 |
+
# Prototype choice: update predictors from true gradient because the
|
| 283 |
+
# full gradient is available for measurement. A speedup version would
|
| 284 |
+
# update predictors fully only on refresh steps and partially on active blocks.
|
| 285 |
+
self.pred_w[li][row].mul_(self.grad_beta).add_(g_w, alpha=1.0 - self.grad_beta)
|
| 286 |
+
if layer.bias is not None:
|
| 287 |
+
self.pred_b[li][row].mul_(self.grad_beta).add_(g_b, alpha=1.0 - self.grad_beta)
|
| 288 |
+
|
| 289 |
+
approx_vec_items = [a_w.flatten()]
|
| 290 |
+
if a_b is not None:
|
| 291 |
+
approx_vec_items.append(a_b.view(1))
|
| 292 |
+
|
| 293 |
+
true_parts.append(true_vec_block)
|
| 294 |
+
pred_parts.append(pred_vec_block)
|
| 295 |
+
approx_parts.append(torch.cat(approx_vec_items))
|
| 296 |
+
|
| 297 |
+
# Install approximate gradients for Adam.
|
| 298 |
+
for li, layer in enumerate(self.layers):
|
| 299 |
+
layer.weight.grad.copy_(approx_w[li])
|
| 300 |
+
if layer.bias is not None:
|
| 301 |
+
layer.bias.grad.copy_(approx_b[li])
|
| 302 |
+
|
| 303 |
+
true_vec = torch.cat(true_parts)
|
| 304 |
+
pred_vec = torch.cat(pred_parts)
|
| 305 |
+
approx_vec = torch.cat(approx_parts)
|
| 306 |
+
|
| 307 |
+
true_norm = torch.norm(true_vec)
|
| 308 |
+
pred_error_norm = torch.norm(true_vec - pred_vec)
|
| 309 |
+
|
| 310 |
+
cosine = F.cosine_similarity(true_vec, approx_vec, dim=0).item()
|
| 311 |
+
approx_norm_ratio = float((torch.norm(approx_vec) / (true_norm + self.eps)).item())
|
| 312 |
+
pred_explained = 1.0 - (
|
| 313 |
+
pred_error_norm.pow(2) / (true_norm.pow(2) + self.eps)
|
| 314 |
+
).item()
|
| 315 |
+
|
| 316 |
+
oracle_magnitude_mask = self._topk_mask(new_gradient_mass, self.active_fraction)
|
| 317 |
+
predicted_magnitude_mask = self._topk_mask(self.predicted_gradient_mass, self.active_fraction)
|
| 318 |
+
|
| 319 |
+
active_vs_oracle_mag = self._jaccard(active, oracle_magnitude_mask)
|
| 320 |
+
predmag_vs_oracle_mag = self._jaccard(predicted_magnitude_mask, oracle_magnitude_mask)
|
| 321 |
+
active_stability = self._jaccard(active, self.prev_active)
|
| 322 |
+
self.prev_active = active.clone()
|
| 323 |
+
|
| 324 |
+
k20 = max(1, int(0.2 * len(self.blocks)))
|
| 325 |
+
sorted_gradient = torch.sort(new_gradient_mass.detach(), descending=True).values
|
| 326 |
+
top20_gradient_mass = (sorted_gradient[:k20].sum() / (sorted_gradient.sum() + self.eps)).item()
|
| 327 |
+
|
| 328 |
+
# Update mass trackers AFTER diagnostics, so predicted_magnitude really
|
| 329 |
+
# uses only history at selection time.
|
| 330 |
+
self.current_gradient_mass = new_gradient_mass
|
| 331 |
+
self.predicted_gradient_mass.mul_(self.mass_beta).add_(new_gradient_mass, alpha=1.0 - self.mass_beta)
|
| 332 |
+
|
| 333 |
+
return {
|
| 334 |
+
"is_refresh": float(is_refresh),
|
| 335 |
+
"active_fraction": float(active.float().mean().item()),
|
| 336 |
+
"cosine_true_vs_approx": cosine,
|
| 337 |
+
"approx_norm_ratio": approx_norm_ratio,
|
| 338 |
+
"pred_explained_fraction": pred_explained,
|
| 339 |
+
"top20_gradient_mass": top20_gradient_mass,
|
| 340 |
+
"active_vs_oracle_mag": active_vs_oracle_mag,
|
| 341 |
+
"predmag_vs_oracle_mag": predmag_vs_oracle_mag,
|
| 342 |
+
"active_stability": active_stability,
|
| 343 |
+
"true_grad_norm": float(true_norm.item()),
|
| 344 |
+
"pred_error_norm": float(pred_error_norm.item()),
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# -----------------------------
|
| 349 |
+
# Metrics and training
|
| 350 |
+
# -----------------------------
|
| 351 |
+
|
| 352 |
+
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
|
| 353 |
+
model.eval()
|
| 354 |
+
with torch.no_grad():
|
| 355 |
+
pred = model(X).argmax(dim=1)
|
| 356 |
+
return (pred == y).float().mean().item()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def train_baseline(
|
| 360 |
+
X: torch.Tensor,
|
| 361 |
+
y: torch.Tensor,
|
| 362 |
+
steps: int = 2000,
|
| 363 |
+
batch_size: int = 256,
|
| 364 |
+
lr: float = 1e-3,
|
| 365 |
+
) -> Tuple[nn.Module, List[Dict[str, float]]]:
|
| 366 |
+
model = TinyMLP().to(DEVICE)
|
| 367 |
+
opt = torch.optim.Adam(model.parameters(), lr=lr)
|
| 368 |
+
history: List[Dict[str, float]] = []
|
| 369 |
+
|
| 370 |
+
for step in range(steps):
|
| 371 |
+
model.train()
|
| 372 |
+
idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
|
| 373 |
+
xb, yb = X[idx], y[idx]
|
| 374 |
+
|
| 375 |
+
loss = F.cross_entropy(model(xb), yb)
|
| 376 |
+
|
| 377 |
+
opt.zero_grad(set_to_none=True)
|
| 378 |
+
loss.backward()
|
| 379 |
+
opt.step()
|
| 380 |
+
|
| 381 |
+
if step % 97 == 0 or step == steps - 1:
|
| 382 |
+
history.append({
|
| 383 |
+
"step": step,
|
| 384 |
+
"loss": float(loss.item()),
|
| 385 |
+
"accuracy": accuracy(model, X, y),
|
| 386 |
+
})
|
| 387 |
+
|
| 388 |
+
return model, history
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def train_sparse_policy(
|
| 392 |
+
X: torch.Tensor,
|
| 393 |
+
y: torch.Tensor,
|
| 394 |
+
policy: Policy,
|
| 395 |
+
inactive_mode: InactiveMode,
|
| 396 |
+
active_fraction: float,
|
| 397 |
+
steps: int = 2000,
|
| 398 |
+
batch_size: int = 256,
|
| 399 |
+
lr: float = 1e-3,
|
| 400 |
+
refresh_interval: int = 10,
|
| 401 |
+
warmup_steps: int = 100,
|
| 402 |
+
grad_beta: float = 0.95,
|
| 403 |
+
mass_beta: float = 0.95,
|
| 404 |
+
) -> Tuple[nn.Module, List[Dict[str, float]]]:
|
| 405 |
+
model = TinyMLP().to(DEVICE)
|
| 406 |
+
opt = torch.optim.Adam(model.parameters(), lr=lr)
|
| 407 |
+
builder = SparseGradientBuilder(
|
| 408 |
+
model,
|
| 409 |
+
policy=policy,
|
| 410 |
+
inactive_mode=inactive_mode,
|
| 411 |
+
grad_beta=grad_beta,
|
| 412 |
+
mass_beta=mass_beta,
|
| 413 |
+
active_fraction=active_fraction,
|
| 414 |
+
refresh_interval=refresh_interval,
|
| 415 |
+
warmup_steps=warmup_steps,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
history: List[Dict[str, float]] = []
|
| 419 |
+
|
| 420 |
+
for step in range(steps):
|
| 421 |
+
model.train()
|
| 422 |
+
idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
|
| 423 |
+
xb, yb = X[idx], y[idx]
|
| 424 |
+
|
| 425 |
+
loss = F.cross_entropy(model(xb), yb)
|
| 426 |
+
|
| 427 |
+
opt.zero_grad(set_to_none=True)
|
| 428 |
+
loss.backward()
|
| 429 |
+
|
| 430 |
+
diagnostics = builder.build_and_install_approx_grads(step)
|
| 431 |
+
opt.step()
|
| 432 |
+
|
| 433 |
+
if step % 97 == 0 or step == steps - 1:
|
| 434 |
+
history.append({
|
| 435 |
+
"step": step,
|
| 436 |
+
"loss": float(loss.item()),
|
| 437 |
+
"accuracy": accuracy(model, X, y),
|
| 438 |
+
**diagnostics,
|
| 439 |
+
})
|
| 440 |
+
|
| 441 |
+
return model, history
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def sparse_rows(history: List[Dict[str, float]]) -> List[Dict[str, float]]:
|
| 445 |
+
return [row for row in history if row.get("is_refresh", 0.0) == 0.0]
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def avg_sparse_metric(history: List[Dict[str, float]], key: str) -> float:
|
| 449 |
+
rows = sparse_rows(history)
|
| 450 |
+
vals = [row[key] for row in rows]
|
| 451 |
+
if not vals:
|
| 452 |
+
return float("nan")
|
| 453 |
+
return sum(vals) / len(vals)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def summarize_sparse_run(
|
| 457 |
+
policy: Policy,
|
| 458 |
+
inactive_mode: InactiveMode,
|
| 459 |
+
active_fraction: float,
|
| 460 |
+
history: List[Dict[str, float]],
|
| 461 |
+
) -> Dict[str, float | str]:
|
| 462 |
+
last = history[-1]
|
| 463 |
+
return {
|
| 464 |
+
"policy": policy,
|
| 465 |
+
"inactive_mode": inactive_mode,
|
| 466 |
+
"active_fraction": active_fraction,
|
| 467 |
+
"final_accuracy": last["accuracy"],
|
| 468 |
+
"final_loss": last["loss"],
|
| 469 |
+
"avg_cosine": avg_sparse_metric(history, "cosine_true_vs_approx"),
|
| 470 |
+
"avg_norm_ratio": avg_sparse_metric(history, "approx_norm_ratio"),
|
| 471 |
+
"avg_top20_grad_mass": avg_sparse_metric(history, "top20_gradient_mass"),
|
| 472 |
+
"avg_jacc_mag": avg_sparse_metric(history, "active_vs_oracle_mag"),
|
| 473 |
+
"avg_predmag_jacc": avg_sparse_metric(history, "predmag_vs_oracle_mag"),
|
| 474 |
+
"avg_stability": avg_sparse_metric(history, "active_stability"),
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def print_baseline(history: List[Dict[str, float]]) -> None:
|
| 479 |
+
print("\nBaseline full Adam")
|
| 480 |
+
for k, v in history[-1].items():
|
| 481 |
+
if isinstance(v, float):
|
| 482 |
+
print(f" {k:24s}: {v:.4f}")
|
| 483 |
+
else:
|
| 484 |
+
print(f" {k:24s}: {v}")
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def print_summary_table(rows: List[Dict[str, float | str]]) -> None:
|
| 488 |
+
print("\nSparse run summary")
|
| 489 |
+
header = (
|
| 490 |
+
f"{'policy':>20s} {'mode':>6s} {'active':>7s} "
|
| 491 |
+
f"{'acc':>7s} {'loss':>9s} {'cos':>7s} {'norm':>7s} "
|
| 492 |
+
f"{'top20':>7s} {'jacc':>7s} {'stable':>7s}"
|
| 493 |
+
)
|
| 494 |
+
print(header)
|
| 495 |
+
print("-" * len(header))
|
| 496 |
+
|
| 497 |
+
for row in rows:
|
| 498 |
+
print(
|
| 499 |
+
f"{str(row['policy']):>20s} "
|
| 500 |
+
f"{str(row['inactive_mode']):>6s} "
|
| 501 |
+
f"{float(row['active_fraction']):7.2f} "
|
| 502 |
+
f"{float(row['final_accuracy']):7.4f} "
|
| 503 |
+
f"{float(row['final_loss']):9.4f} "
|
| 504 |
+
f"{float(row['avg_cosine']):7.3f} "
|
| 505 |
+
f"{float(row['avg_norm_ratio']):7.3f} "
|
| 506 |
+
f"{float(row['avg_top20_grad_mass']):7.3f} "
|
| 507 |
+
f"{float(row['avg_jacc_mag']):7.3f} "
|
| 508 |
+
f"{float(row['avg_stability']):7.3f}"
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def main() -> None:
|
| 513 |
+
X, y = make_spiral()
|
| 514 |
+
|
| 515 |
+
baseline_model, baseline_hist = train_baseline(X, y)
|
| 516 |
+
print_baseline(baseline_hist)
|
| 517 |
+
|
| 518 |
+
active_fractions = [0.20, 0.10, 0.05, 0.02]
|
| 519 |
+
|
| 520 |
+
# Keep this matrix focused; it runs 20 sparse experiments.
|
| 521 |
+
experiment_plan: List[Tuple[Policy, InactiveMode, float]] = []
|
| 522 |
+
|
| 523 |
+
for active_fraction in active_fractions:
|
| 524 |
+
experiment_plan.append(("predicted_magnitude", "ema", active_fraction))
|
| 525 |
+
experiment_plan.append(("predicted_magnitude", "zero", active_fraction))
|
| 526 |
+
experiment_plan.append(("magnitude", "zero", active_fraction))
|
| 527 |
+
experiment_plan.append(("random", "zero", active_fraction))
|
| 528 |
+
|
| 529 |
+
summary_rows: List[Dict[str, float | str]] = []
|
| 530 |
+
|
| 531 |
+
for policy, inactive_mode, active_fraction in experiment_plan:
|
| 532 |
+
print(
|
| 533 |
+
f"\nRunning policy={policy}, inactive_mode={inactive_mode}, "
|
| 534 |
+
f"active_fraction={active_fraction:.2f}"
|
| 535 |
+
)
|
| 536 |
+
_, hist = train_sparse_policy(
|
| 537 |
+
X,
|
| 538 |
+
y,
|
| 539 |
+
policy=policy,
|
| 540 |
+
inactive_mode=inactive_mode,
|
| 541 |
+
active_fraction=active_fraction,
|
| 542 |
+
refresh_interval=10,
|
| 543 |
+
warmup_steps=100,
|
| 544 |
+
grad_beta=0.95,
|
| 545 |
+
mass_beta=0.95,
|
| 546 |
+
)
|
| 547 |
+
summary_rows.append(summarize_sparse_run(policy, inactive_mode, active_fraction, hist))
|
| 548 |
+
|
| 549 |
+
print_summary_table(summary_rows)
|
| 550 |
+
|
| 551 |
+
print("\nHow to read this:")
|
| 552 |
+
print(" predicted_magnitude + zero is the main practical test.")
|
| 553 |
+
print(" magnitude + zero is an oracle-ish upper bound using recent observed mass.")
|
| 554 |
+
print(" random + zero is the control.")
|
| 555 |
+
print(" acc close to baseline means sparse updates preserved learning.")
|
| 556 |
+
print(" cos near 1.0 means sparse update direction matches full gradient direction.")
|
| 557 |
+
print(" norm much below 1.0 means the sparse update is much smaller than full gradient.")
|
| 558 |
+
print(" top20 near 0.7+ means gradient mass is concentrated/heavy-tailed.")
|
| 559 |
+
print(" jacc above random means active-set prediction finds the true important blocks.")
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
if __name__ == "__main__":
|
| 563 |
+
main()
|
experiments/surprise_topk_gradient_prototype.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Surprise Top-K Gradient Prototype
|
| 3 |
+
|
| 4 |
+
This is a deliberately small, readable PyTorch experiment for testing the idea:
|
| 5 |
+
|
| 6 |
+
gradient_t ≈ predicted_gradient_t + sparse_surprising_residual_t
|
| 7 |
+
|
| 8 |
+
The code compares:
|
| 9 |
+
1. Baseline full SGD
|
| 10 |
+
2. SurpriseTopK training, where only high-surprise parameter blocks use the true
|
| 11 |
+
gradient on cheap steps; low-surprise blocks use a predicted/stale gradient.
|
| 12 |
+
|
| 13 |
+
Important caveat:
|
| 14 |
+
This prototype still computes full gradients on every step so we can evaluate the
|
| 15 |
+
approximation honestly. It simulates reduced backward/update entropy; it does not
|
| 16 |
+
yet provide real wall-clock acceleration. Real acceleration would require structured
|
| 17 |
+
partial backward passes, custom kernels, or graph-level masking.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
import random
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from typing import Dict, List, Tuple
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# -----------------------------
|
| 33 |
+
# Reproducibility
|
| 34 |
+
# -----------------------------
|
| 35 |
+
|
| 36 |
+
SEED = 7
|
| 37 |
+
random.seed(SEED)
|
| 38 |
+
torch.manual_seed(SEED)
|
| 39 |
+
|
| 40 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# -----------------------------
|
| 44 |
+
# Toy data: 2-class spiral
|
| 45 |
+
# -----------------------------
|
| 46 |
+
|
| 47 |
+
def make_spiral(n_per_class: int = 512, noise: float = 0.2) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 48 |
+
"""Create a small nonlinear classification problem without external datasets."""
|
| 49 |
+
xs = []
|
| 50 |
+
ys = []
|
| 51 |
+
|
| 52 |
+
for class_id in range(2):
|
| 53 |
+
r = torch.linspace(0.0, 1.0, n_per_class)
|
| 54 |
+
theta = class_id * math.pi + r * 4.0 * math.pi
|
| 55 |
+
theta = theta + torch.randn(n_per_class) * noise
|
| 56 |
+
|
| 57 |
+
x = torch.stack([r * torch.sin(theta), r * torch.cos(theta)], dim=1)
|
| 58 |
+
y = torch.full((n_per_class,), class_id, dtype=torch.long)
|
| 59 |
+
|
| 60 |
+
xs.append(x)
|
| 61 |
+
ys.append(y)
|
| 62 |
+
|
| 63 |
+
X = torch.cat(xs, dim=0)
|
| 64 |
+
Y = torch.cat(ys, dim=0)
|
| 65 |
+
|
| 66 |
+
perm = torch.randperm(X.shape[0])
|
| 67 |
+
return X[perm].to(DEVICE), Y[perm].to(DEVICE)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# -----------------------------
|
| 71 |
+
# Model
|
| 72 |
+
# -----------------------------
|
| 73 |
+
|
| 74 |
+
class TinyMLP(nn.Module):
|
| 75 |
+
def __init__(self, width: int = 64):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.net = nn.Sequential(
|
| 78 |
+
nn.Linear(2, width),
|
| 79 |
+
nn.Tanh(),
|
| 80 |
+
nn.Linear(width, width),
|
| 81 |
+
nn.Tanh(),
|
| 82 |
+
nn.Linear(width, 2),
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 86 |
+
return self.net(x)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def linear_layers(model: nn.Module) -> List[nn.Linear]:
|
| 90 |
+
return [m for m in model.modules() if isinstance(m, nn.Linear)]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# -----------------------------
|
| 94 |
+
# Block bookkeeping
|
| 95 |
+
# -----------------------------
|
| 96 |
+
|
| 97 |
+
@dataclass
|
| 98 |
+
class BlockRef:
|
| 99 |
+
"""A block is one output row of a Linear layer, plus its bias element if present."""
|
| 100 |
+
layer_index: int
|
| 101 |
+
row_index: int
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class SurpriseTopKUpdater:
|
| 105 |
+
"""
|
| 106 |
+
Applies a predicted-gradient + top-surprise true-gradient update.
|
| 107 |
+
|
| 108 |
+
Unit of sparsity:
|
| 109 |
+
One output row of each Linear layer.
|
| 110 |
+
|
| 111 |
+
Why rows?
|
| 112 |
+
Row blocks correspond roughly to neurons/features, and structured blocks are
|
| 113 |
+
much closer to real hardware speedup than individual unstructured weights.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
def __init__(
|
| 117 |
+
self,
|
| 118 |
+
model: nn.Module,
|
| 119 |
+
lr: float = 0.05,
|
| 120 |
+
beta: float = 0.9,
|
| 121 |
+
active_fraction: float = 0.2,
|
| 122 |
+
refresh_interval: int = 10,
|
| 123 |
+
use_error_feedback: bool = True,
|
| 124 |
+
eps: float = 1e-12,
|
| 125 |
+
):
|
| 126 |
+
self.model = model
|
| 127 |
+
self.layers = linear_layers(model)
|
| 128 |
+
self.lr = lr
|
| 129 |
+
self.beta = beta
|
| 130 |
+
self.active_fraction = active_fraction
|
| 131 |
+
self.refresh_interval = refresh_interval
|
| 132 |
+
self.use_error_feedback = use_error_feedback
|
| 133 |
+
self.eps = eps
|
| 134 |
+
|
| 135 |
+
self.blocks: List[BlockRef] = []
|
| 136 |
+
for li, layer in enumerate(self.layers):
|
| 137 |
+
for row in range(layer.weight.shape[0]):
|
| 138 |
+
self.blocks.append(BlockRef(li, row))
|
| 139 |
+
|
| 140 |
+
# Predicted gradients, same shape as parameters.
|
| 141 |
+
self.pred_w: Dict[int, torch.Tensor] = {}
|
| 142 |
+
self.pred_b: Dict[int, torch.Tensor] = {}
|
| 143 |
+
|
| 144 |
+
# Error-feedback buffers accumulate information we did not apply.
|
| 145 |
+
self.err_w: Dict[int, torch.Tensor] = {}
|
| 146 |
+
self.err_b: Dict[int, torch.Tensor] = {}
|
| 147 |
+
|
| 148 |
+
# Surprise scores per block. Higher means more worth computing/updating exactly.
|
| 149 |
+
self.scores = torch.ones(len(self.blocks), device=DEVICE)
|
| 150 |
+
|
| 151 |
+
for li, layer in enumerate(self.layers):
|
| 152 |
+
self.pred_w[li] = torch.zeros_like(layer.weight.data)
|
| 153 |
+
self.err_w[li] = torch.zeros_like(layer.weight.data)
|
| 154 |
+
if layer.bias is not None:
|
| 155 |
+
self.pred_b[li] = torch.zeros_like(layer.bias.data)
|
| 156 |
+
self.err_b[li] = torch.zeros_like(layer.bias.data)
|
| 157 |
+
|
| 158 |
+
def _block_grad_vector(self, li: int, row: int) -> torch.Tensor:
|
| 159 |
+
layer = self.layers[li]
|
| 160 |
+
parts = [layer.weight.grad[row].flatten()]
|
| 161 |
+
if layer.bias is not None:
|
| 162 |
+
parts.append(layer.bias.grad[row].view(1))
|
| 163 |
+
return torch.cat(parts)
|
| 164 |
+
|
| 165 |
+
def _block_pred_vector(self, li: int, row: int) -> torch.Tensor:
|
| 166 |
+
layer = self.layers[li]
|
| 167 |
+
parts = [self.pred_w[li][row].flatten()]
|
| 168 |
+
if layer.bias is not None:
|
| 169 |
+
parts.append(self.pred_b[li][row].view(1))
|
| 170 |
+
return torch.cat(parts)
|
| 171 |
+
|
| 172 |
+
def _choose_active_blocks(self, step: int) -> torch.Tensor:
|
| 173 |
+
"""Return boolean mask over blocks."""
|
| 174 |
+
n_blocks = len(self.blocks)
|
| 175 |
+
|
| 176 |
+
# Full refresh: observe/update everything.
|
| 177 |
+
if step % self.refresh_interval == 0:
|
| 178 |
+
return torch.ones(n_blocks, dtype=torch.bool, device=DEVICE)
|
| 179 |
+
|
| 180 |
+
k = max(1, int(self.active_fraction * n_blocks))
|
| 181 |
+
active = torch.zeros(n_blocks, dtype=torch.bool, device=DEVICE)
|
| 182 |
+
top_idx = torch.topk(self.scores, k=k).indices
|
| 183 |
+
active[top_idx] = True
|
| 184 |
+
return active
|
| 185 |
+
|
| 186 |
+
    @torch.no_grad()
    def step(self, step: int) -> Dict[str, float]:
        """
        Apply one optimizer step after loss.backward().

        Active (high-surprise) blocks apply their exact observed gradient and
        refresh the EMA predictor; inactive blocks apply the stale predicted
        gradient. Skipped mass is pushed into error-feedback buffers so it is
        re-applied later instead of being lost. Gradients are cleared at the
        end, so callers must not call opt.zero_grad themselves.

        Returns diagnostics comparing the approximate update to the true gradient.
        """
        active = self._choose_active_blocks(step)

        # Per-block flattened vectors collected for the diagnostics below.
        true_flat = []
        applied_flat = []
        pred_flat = []

        active_count = int(active.sum().item())
        total_count = len(self.blocks)

        for block_id, block in enumerate(self.blocks):
            li = block.layer_index
            row = block.row_index
            layer = self.layers[li]
            is_active = bool(active[block_id].item())

            # Clone so later in-place predictor/buffer updates cannot alias
            # the values used for the applied update and diagnostics.
            g_w = layer.weight.grad[row].clone()
            g_b = layer.bias.grad[row].clone() if layer.bias is not None else None

            # Error feedback makes skipped/prediction error come back later instead
            # of disappearing forever.
            if self.use_error_feedback:
                g_w_eff = g_w + self.err_w[li][row]
                g_b_eff = g_b + self.err_b[li][row] if g_b is not None else None
            else:
                g_w_eff = g_w
                g_b_eff = g_b

            # Snapshot of the predictor BEFORE any EMA update this step.
            p_w = self.pred_w[li][row].clone()
            p_b = self.pred_b[li][row].clone() if layer.bias is not None else None

            if is_active:
                # Use the exact observed gradient for high-surprise blocks.
                applied_w = g_w_eff
                applied_b = g_b_eff

                # Update predictor only where we pretend we actually observed the gradient.
                self.pred_w[li][row].mul_(self.beta).add_(g_w, alpha=1.0 - self.beta)
                if layer.bias is not None:
                    self.pred_b[li][row].mul_(self.beta).add_(g_b, alpha=1.0 - self.beta)

                # Update surprise score: how wrong was the current predictor?
                # NOTE(review): _block_pred_vector reads the predictor AFTER the
                # EMA update above, so the residual measures the post-update
                # predictor, not the one that was in effect this step (p_w/p_b)
                # — confirm this is intended.
                pred_vec = self._block_pred_vector(li, row)
                grad_vec = self._block_grad_vector(li, row)
                residual_norm = torch.norm(grad_vec - pred_vec)
                grad_norm = torch.norm(grad_vec)
                self.scores[block_id] = residual_norm / (grad_norm + self.eps)
            else:
                # Use the predicted/stale gradient for low-surprise blocks.
                applied_w = p_w
                applied_b = p_b

            # Update error-feedback buffers with whatever was not applied.
            if self.use_error_feedback:
                self.err_w[li][row] = g_w_eff - applied_w
                if layer.bias is not None:
                    self.err_b[li][row] = g_b_eff - applied_b

            # Apply update (plain SGD on this block's row).
            layer.weight.data[row].add_(applied_w, alpha=-self.lr)
            if layer.bias is not None:
                layer.bias.data[row].add_(applied_b, alpha=-self.lr)

            # Diagnostics: true gradient, applied update, and pre-update prediction.
            true_parts = [g_w.flatten()]
            app_parts = [applied_w.flatten()]
            pred_parts = [p_w.flatten()]

            if layer.bias is not None:
                true_parts.append(g_b.view(1))
                app_parts.append(applied_b.view(1))
                pred_parts.append(p_b.view(1))

            true_flat.append(torch.cat(true_parts))
            applied_flat.append(torch.cat(app_parts))
            pred_flat.append(torch.cat(pred_parts))

        true_vec = torch.cat(true_flat)
        applied_vec = torch.cat(applied_flat)
        pred_vec = torch.cat(pred_flat)

        # How aligned is the applied update with the true gradient direction?
        cosine = F.cosine_similarity(true_vec, applied_vec, dim=0).item()
        # R^2-style fraction of gradient energy explained by the predictor.
        pred_explained = 1.0 - (
            torch.norm(true_vec - pred_vec).pow(2) / (torch.norm(true_vec).pow(2) + self.eps)
        ).item()

        # Heavy-tail diagnostic: how much surprise mass lives in the top 20% blocks?
        k20 = max(1, int(0.2 * total_count))
        sorted_scores = torch.sort(self.scores.detach(), descending=True).values
        top20_mass = (sorted_scores[:k20].sum() / (sorted_scores.sum() + self.eps)).item()

        # Clear gradients manually.
        for p in self.model.parameters():
            p.grad = None

        return {
            "active_fraction": active_count / total_count,
            "cosine_true_vs_applied": cosine,
            "pred_explained_fraction": pred_explained,
            "top20_surprise_mass": top20_mass,
        }
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# -----------------------------
|
| 296 |
+
# Training loops
|
| 297 |
+
# -----------------------------
|
| 298 |
+
|
| 299 |
+
def accuracy(model: nn.Module, X: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of samples whose argmax prediction matches the label."""
    model.eval()
    with torch.no_grad():
        labels_hat = torch.argmax(model(X), dim=1)
        hits = (labels_hat == y).float()
        return hits.mean().item()
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def train_baseline(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 600,
    batch_size: int = 128,
    lr: float = 0.05,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a TinyMLP with plain SGD; record loss/accuracy every 25 steps."""
    model = TinyMLP().to(DEVICE)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for step in range(steps):
        batch_idx = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        xb = X[batch_idx]
        yb = y[batch_idx]

        loss = F.cross_entropy(model(xb), yb)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        # Log sparsely (and always on the final step).
        if step % 25 == 0 or step == steps - 1:
            history.append(
                {
                    "step": step,
                    "loss": float(loss.item()),
                    "accuracy": accuracy(model, X, y),
                }
            )

    return model, history
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def train_surprise_topk(
    X: torch.Tensor,
    y: torch.Tensor,
    steps: int = 600,
    batch_size: int = 128,
    lr: float = 0.05,
    active_fraction: float = 0.2,
    refresh_interval: int = 10,
    beta: float = 0.9,
) -> Tuple[nn.Module, List[Dict[str, float]]]:
    """Train a TinyMLP with the SurpriseTopKUpdater; log metrics every 25 steps."""
    model = TinyMLP().to(DEVICE)
    updater = SurpriseTopKUpdater(
        model,
        lr=lr,
        beta=beta,
        active_fraction=active_fraction,
        refresh_interval=refresh_interval,
        use_error_feedback=True,
    )

    history: List[Dict[str, float]] = []

    for step in range(steps):
        sample = torch.randint(0, X.shape[0], (batch_size,), device=DEVICE)
        xb, yb = X[sample], y[sample]

        loss = F.cross_entropy(model(xb), yb)
        loss.backward()

        # The updater applies the step and clears gradients itself.
        diagnostics = updater.step(step)

        if step % 25 == 0 or step == steps - 1:
            entry: Dict[str, float] = {
                "step": step,
                "loss": float(loss.item()),
                "accuracy": accuracy(model, X, y),
            }
            entry.update(diagnostics)
            history.append(entry)

    return model, history
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
# -----------------------------
|
| 383 |
+
# Main experiment
|
| 384 |
+
# -----------------------------
|
| 385 |
+
|
| 386 |
+
def print_last(label: str, history: List[Dict[str, float]]) -> None:
    """Pretty-print the final history entry under a heading."""
    final = history[-1]
    print(f"\n{label}")
    for key, value in final.items():
        # Floats get 4 decimal places; everything else prints verbatim.
        rendered = f"{value:.4f}" if isinstance(value, float) else f"{value}"
        print(f" {key:28s}: {rendered}")
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def main() -> None:
    """Compare baseline SGD vs. Surprise Top-K training on the spiral task."""
    X, y = make_spiral(n_per_class=768, noise=0.18)

    baseline_model, baseline_hist = train_baseline(X, y)

    surprise_model, surprise_hist = train_surprise_topk(
        X,
        y,
        active_fraction=0.2,
        refresh_interval=10,
        beta=0.9,
    )

    print_last("Baseline full SGD", baseline_hist)
    print_last("Surprise Top-K simulated training", surprise_hist)

    print("\nA few Surprise Top-K checkpoints:")
    stride = max(1, len(surprise_hist) // 8)
    for row in surprise_hist[::stride]:
        print(
            f"step={row['step']:4d} "
            f"loss={row['loss']:.4f} "
            f"acc={row['accuracy']:.3f} "
            f"active={row['active_fraction']:.2f} "
            f"cos={row['cosine_true_vs_applied']:.3f} "
            f"pred_expl={row['pred_explained_fraction']:.3f} "
            f"top20_mass={row['top20_surprise_mass']:.3f}"
        )


if __name__ == "__main__":
    main()
|
triton_sparse.py → experiments/triton_sparse.py
RENAMED
|
File without changes
|
triton_v2.py → experiments/triton_v2.py
RENAMED
|
File without changes
|
experiments/uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
paper/main.tex
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
\documentclass[11pt]{article}
|
| 2 |
+
|
| 3 |
+
\usepackage[margin=1in]{geometry}
|
| 4 |
+
\usepackage{amsmath,amssymb}
|
| 5 |
+
\usepackage{graphicx}
|
| 6 |
+
\usepackage{microtype}
|
| 7 |
+
\usepackage{booktabs}
|
| 8 |
+
\usepackage{tabularx}
|
| 9 |
+
\usepackage{hyperref}
|
| 10 |
+
|
| 11 |
+
\hypersetup{
|
| 12 |
+
colorlinks=true,
|
| 13 |
+
linkcolor=blue,
|
| 14 |
+
citecolor=blue,
|
| 15 |
+
urlcolor=blue,
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
\title{%
|
| 19 |
+
\textbf{Zero-Copy Sparse Backpropagation}: \\[0.3em]
|
| 20 |
+
\large Temporal Gradient Tracking for
|
| 21 |
+
Faster, Regularized LLM Training
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
\author{
|
| 25 |
+
Daniel Owen van Dommelen\\
|
| 26 |
+
\textit{Independent Research - WORKING DRAFT}\\
|
| 27 |
+
\texttt{theapemachine@gmail.com}
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
\date{\today}
|
| 31 |
+
|
| 32 |
+
\begin{document}
|
| 33 |
+
\maketitle
|
| 34 |
+
|
| 35 |
+
\begin{abstract}
|
| 36 |
+
We describe \emph{Predictive Chunked Sparsity}: fixed top-$k$ row-chunks for
|
| 37 |
+
sparse $dW$, selected by an EMA of past chunk-gradient norms, with contiguous
|
| 38 |
+
slices for PyTorch-style GEMMs. On \textbf{Apple MPS} (full 6-layer runs:
|
| 39 |
+
$B{=}8$, $T{=}256$, chunk 64, $10\%$ active, 2000 steps), sparse training is
|
| 40 |
+
\textbf{slower} at $d_{\text{model}}{=}512$ ($\sim$1.22$\times$ higher ms/step than
|
| 41 |
+
dense for both $G_X$ modes) but \textbf{faster} at $d{=}2048$ ($\sim$1.18$\times$
|
| 42 |
+
and $\sim$1.22$\times$ speedup for full-$G_X$ and sparse-$G_X$ respectively,
|
| 43 |
+
with validation loss reported in Table~\ref{tab:mps-e2e}).
|
| 44 |
+
|
| 45 |
+
On \textbf{NVIDIA T4}, an isolated single-FFN timing harness (100 iters, fp32,
|
| 46 |
+
same $B,T$, chunk 64, $10\%$ active) shows full-$G_X$ totals from
|
| 47 |
+
1.02$\times$ at $d{=}256$ to 1.35$\times$ at $d{=}2048$
|
| 48 |
+
(Table~\ref{tab:t4-ffn-micro}). A fused \textbf{Triton} backward passes numeric
|
| 49 |
+
checks (Table~\ref{tab:triton-correctness}); isolated backward on T4 improves
|
| 50 |
+
over dense for $d\ge 512$ but can trail PyTorch at $d{=}256$
|
| 51 |
+
(Table~\ref{tab:triton-backward}). Short \textbf{T4 end-to-end} training (100
|
| 52 |
+
steps) shows modest PyLoop gains at $d{=}512$/1024 and Triton autotune/noise
|
| 53 |
+
hurting at small scale (Table~\ref{tab:t4-e2e}). EMA--oracle chunk overlap on one
|
| 54 |
+
seed is in Table~\ref{tab:ema-overlap}; multi-seed long runs were pending at
|
| 55 |
+
draft time.
|
| 56 |
+
\end{abstract}
|
| 57 |
+
|
| 58 |
+
\section{Introduction}
|
| 59 |
+
Training transformers is dominated by dense matmuls. Some work reports
|
| 60 |
+
heavy-tailed gradient coordinates; whether that yields wall-clock savings depends
|
| 61 |
+
on implementation and hardware. Dynamic sparsity often hits irregular memory
|
| 62 |
+
access and, for variable masks, possible host--device coordination for shapes.
|
| 63 |
+
We use \emph{fixed-cardinality} chunk masks, an EMA scorer, cosine annealing,
|
| 64 |
+
and strided views (and optionally Triton) so active tiles map to dense GEMMs.
|
| 65 |
+
Contributions are \textbf{(1)} the algorithmic recipe, \textbf{(2)}
|
| 66 |
+
reproducible tables for MPS full training, T4 microbenchmarks, Triton
|
| 67 |
+
correctness and speed, short T4 E2E, chunk-size timing, and \textbf{(3)} honest
|
| 68 |
+
limits: speedups are width-, backend-, and workload-dependent.
|
| 69 |
+
|
| 70 |
+
\section{Methodology: Predictive Chunked Sparsity}
|
| 71 |
+
Linear $W\in\mathbb{R}^{O\times I}$ is split into $N$ row chunks of size $C$.
|
| 72 |
+
Binary mask $A\in\{0,1\}^N$ picks active chunks; inactive $dW$ is zeroed into
|
| 73 |
+
the optimizer. EMA on observed chunk norms $M_c^{(t)}=\beta
|
| 74 |
+
M_c^{(t-1)}+(1-\beta)\|G_{W_c}\|_2$ (active); $M_c^{(t)}=\gamma M_c^{(t-1)}$
|
| 75 |
+
(inactive). Top-$k$ chunks from $M^{(t-1)}$ fix $A$ at step $t$. Cosine schedule
|
| 76 |
+
$S(t)$ warms up fully dense then anneals toward $S_{\text{target}}$. With AdamW,
|
| 77 |
+
$g{=}0$ on inactive weights yields decaying moments (``phantom momentum'')---
|
| 78 |
+
standard Adam side effect, not a separate contribution.
|
| 79 |
+
|
| 80 |
+
\section{Systems}
|
| 81 |
+
Fixed $k$ avoids mask-derived index tensor sizes. Chunk rows are contiguous
|
| 82 |
+
slices (\texttt{gy\_flat[:, s:e]}). PyTorch normally implements that as a view;
|
| 83 |
+
exact behavior is version-dependent. A Python loop over active chunks issues
|
| 84 |
+
multiple kernel launches; Triton fusion targets that overhead (see
|
| 85 |
+
Table~\ref{tab:triton-backward}).
|
| 86 |
+
|
| 87 |
+
\section{Experiments and Results}
|
| 88 |
+
All numbers below are from recorded runs; GPU, hyperparameters, and seed are
|
| 89 |
+
stated per table. We do not claim universal ranking of backends.
|
| 90 |
+
|
| 91 |
+
\subsection{Full training on Apple MPS (author runs)}
|
| 92 |
+
Six layers, $B{=}8$, $T{=}256$, chunk\_size${=}64$, $10\%$ active chunks, 2000
|
| 93 |
+
optimization steps. Times are total wall for 2000 steps; ms/step derived.
|
| 94 |
+
|
| 95 |
+
\begin{table}[t]
|
| 96 |
+
\centering
|
| 97 |
+
\caption{MPS full training (single-seed author configuration per run).}
|
| 98 |
+
\label{tab:mps-e2e}
|
| 99 |
+
\begin{tabular}{l l r r r}
|
| 100 |
+
\toprule
|
| 101 |
+
$d_{\text{model}}$ & Run & Time (s) & ms/step & Val.\ loss \\
|
| 102 |
+
\midrule
|
| 103 |
+
512 & \texttt{dense\_baseline} & 74.77 & 99.70 & 5.3142 \\
|
| 104 |
+
512 & \texttt{sparse\_full\_dX} & 91.04 & 121.38 & 5.4141 \\
|
| 105 |
+
512 & \texttt{sparse\_sparse\_dX} & 93.33 & 124.44 & 5.5467 \\
|
| 106 |
+
\midrule
|
| 107 |
+
2048 & \texttt{dense\_baseline} & 1035.84 & 591.91 & 6.0264 \\
|
| 108 |
+
2048 & \texttt{sparse\_full\_dX} & 875.51 & 500.29 & 5.9807 \\
|
| 109 |
+
2048 & \texttt{sparse\_sparse\_dX} & 847.22 & 484.13 & 6.0231 \\
|
| 110 |
+
\bottomrule
|
| 111 |
+
\end{tabular}
|
| 112 |
+
\end{table}
|
| 113 |
+
|
| 114 |
+
At $d{=}512$, sparse ms/step is $\sim$1.22$\times$ (\texttt{sparse\_full\_dX})
|
| 115 |
+
and $\sim$1.25$\times$ (\texttt{sparse\_sparse\_dX}) vs.\ dense---\emph{slower}.
|
| 116 |
+
At $d{=}2048$, sparse is $\sim$1.18$\times$ and $\sim$1.22$\times$
|
| 117 |
+
\emph{faster}. Validation loss at $d{=}2048$ is best for
|
| 118 |
+
\texttt{sparse\_full\_dX} in this table; at $d{=}512$ dense is best.
|
| 119 |
+
|
| 120 |
+
\subsection{Isolated FFN layer microbenchmark (T4)}
|
| 121 |
+
One FFN block, $M{=}2048$, $B{=}8$, $T{=}256$, chunk\_size${=}64$, $10\%$ active,
|
| 122 |
+
fp32, 100 iterations. Components: forward, $dX$, $dW$ dense vs.\ sparse;
|
| 123 |
+
\emph{full\_$G_X$} total = sum with dense $dX$.
|
| 124 |
+
|
| 125 |
+
\begin{table}[t]
|
| 126 |
+
\centering
|
| 127 |
+
\caption{T4: per--FFN-layer times (ms). Spd. $=$ Tot.\ den.\,/{}Tot.\ sp.f.;
|
| 128 |
+
sparse total uses dense $dX$ (full\_dX).}
|
| 129 |
+
\label{tab:t4-ffn-micro}
|
| 130 |
+
\resizebox{\linewidth}{!}{%
|
| 131 |
+
\footnotesize
|
| 132 |
+
\begin{tabular}{r r r r r r r r r r}
|
| 133 |
+
\toprule
|
| 134 |
+
$d_{\text{model}}$ & FFN dim & Params & Fwd & $dX$ & $dW_{\mathrm{d}}$ &
|
| 135 |
+
$dW_{\mathrm{s}}$ & Tot.\ den. & Tot.\ sp.f. & Spd. \\
|
| 136 |
+
\midrule
|
| 137 |
+
256 & 1024 & 0.3M & 0.27 & 0.21 & 0.27 & 0.26 & 0.75 & 0.74 & 1.02$\times$ \\
|
| 138 |
+
384 & 1536 & 0.6M & 0.52 & 0.69 & 0.61 & 0.18 & 1.82 & 1.39 & 1.31$\times$ \\
|
| 139 |
+
512 & 2048 & 1.0M & 1.00 & 1.01 & 0.97 & 0.26 & 2.99 & 2.28 & 1.31$\times$ \\
|
| 140 |
+
768 & 3072 & 2.4M & 2.16 & 2.25 & 2.05 & 0.40 & 6.46 & 4.81 & 1.34$\times$ \\
|
| 141 |
+
1024 & 4096 & 4.2M & 3.69 & 3.90 & 3.35 & 0.59 & 10.95 & 8.18 & 1.34$\times$ \\
|
| 142 |
+
1536 & 6144 & 9.4M & 10.33 & 9.03 & 8.14 & 1.30 & 27.50 & 20.66 & 1.33$\times$ \\
|
| 143 |
+
2048 & 8192 & 16.8M & 14.76 & 15.57 & 13.19 & 1.93 & 43.51 & 32.26 & 1.35$\times$ \\
|
| 144 |
+
\bottomrule
|
| 145 |
+
\end{tabular}%
|
| 146 |
+
}
|
| 147 |
+
\end{table}
|
| 148 |
+
|
| 149 |
+
If $dW_{\mathrm{dense}}$ were removed from the dense total, a simple
|
| 150 |
+
illustrative ratio (using the measured forward+$dX$ share) implies a ceiling
|
| 151 |
+
around $\sim$1.42--1.48$\times$ for this harness; crossover for net speedup vs.\
|
| 152 |
+
dense full-$G_X$ is near $d_{\text{model}}\approx 384$ in this table.
|
| 153 |
+
|
| 154 |
+
\subsection{Triton numeric checks (T4)}
|
| 155 |
+
Max absolute errors vs.\ reference (fp32 tolerances in experiment script); all
|
| 156 |
+
marked passing in the run log.
|
| 157 |
+
|
| 158 |
+
\begin{table}[t]
|
| 159 |
+
\centering
|
| 160 |
+
\caption{Triton backward vs.\ reference: max abs error.}
|
| 161 |
+
\label{tab:triton-correctness}
|
| 162 |
+
\begin{tabular}{r r r r r r r}
|
| 163 |
+
\toprule
|
| 164 |
+
$d_{\mathrm{in}}$ & $d_{\mathrm{out}}$ & ch. & $\max|dW|$ & $\max|db|$ & $\max|dX|$ & OK \\
|
| 165 |
+
\midrule
|
| 166 |
+
512 & 2048 & 64 & 0.000320 & 0.000023 & 0.000042 & $\checkmark$ \\
|
| 167 |
+
1024 & 4096 & 64 & 0.000443 & 0.000021 & 0.000092 & $\checkmark$ \\
|
| 168 |
+
256 & 1024 & 32 & 0.000275 & 0.000038 & 0.000019 & $\checkmark$ \\
|
| 169 |
+
\bottomrule
|
| 170 |
+
\end{tabular}
|
| 171 |
+
\end{table}
|
| 172 |
+
|
| 173 |
+
\subsection{Isolated backward: Dense vs.\ PyLoop vs.\ Triton (T4)}
|
| 174 |
+
$M{=}2048$, chunk\_size${=}64$, $10\%$ active, full\_$G_X$ mode, 50 iterations
|
| 175 |
+
post-warmup. Times are full backward ms for the timed region (as recorded).
|
| 176 |
+
|
| 177 |
+
\begin{table}[t]
|
| 178 |
+
\centering
|
| 179 |
+
\caption{T4 isolated backward (ms). Triton/Dense $=$ dense/time\_triton.}
|
| 180 |
+
\label{tab:triton-backward}
|
| 181 |
+
\resizebox{\linewidth}{!}{%
|
| 182 |
+
\footnotesize
|
| 183 |
+
\begin{tabular}{r r r r r r r r r}
|
| 184 |
+
\toprule
|
| 185 |
+
$d_{\text{model}}$ & FFN & Active ch. & Dense & PyLoop & Triton &
|
| 186 |
+
T/Dense & T/PyLoop \\
|
| 187 |
+
\midrule
|
| 188 |
+
256 & 1024 & 1 & 0.39 & 0.40 & 0.46 & 0.85$\times$ & 0.88$\times$ \\
|
| 189 |
+
512 & 2048 & 3 & 1.96 & 1.30 & 1.16 & 1.69$\times$ & 1.12$\times$ \\
|
| 190 |
+
768 & 3072 & 4 & 4.29 & 2.52 & 2.51 & 1.70$\times$ & 1.00$\times$ \\
|
| 191 |
+
1024 & 4096 & 6 & 7.29 & 4.37 & 4.30 & 1.70$\times$ & 1.02$\times$ \\
|
| 192 |
+
1536 & 6144 & 9 & 17.32 & 10.04 & 9.78 & 1.77$\times$ & 1.03$\times$ \\
|
| 193 |
+
2048 & 8192 & 12 & 29.14 & 17.20 & 16.89 & 1.73$\times$ & 1.02$\times$ \\
|
| 194 |
+
\bottomrule
|
| 195 |
+
\end{tabular}%
|
| 196 |
+
}
|
| 197 |
+
\end{table}
|
| 198 |
+
|
| 199 |
+
\noindent\textbf{Triton with both $dW$ and $dX$ sparse} (same harness family;
|
| 200 |
+
user-reported row):
|
| 201 |
+
|
| 202 |
+
\begin{table}[h]
|
| 203 |
+
\centering
|
| 204 |
+
\begin{tabular}{r r r r}
|
| 205 |
+
\toprule
|
| 206 |
+
$d_{\text{model}}$ & Dense (ms) & Triton\_all (ms) & Speedup \\
|
| 207 |
+
\midrule
|
| 208 |
+
512 & 1.96 & 0.41 & 4.83$\times$ \\
|
| 209 |
+
1024 & 7.06 & 1.07 & 6.58$\times$ \\
|
| 210 |
+
2048 & 29.00 & 3.71 & 7.81$\times$ \\
|
| 211 |
+
\bottomrule
|
| 212 |
+
\end{tabular}
|
| 213 |
+
\end{table}
|
| 214 |
+
|
| 215 |
+
At $d{=}256$, Triton is slower than dense in Table~\ref{tab:triton-backward}
|
| 216 |
+
(0.85$\times$); at $d{=}512$, PyTorch single-kernel launches can still be hard
|
| 217 |
+
to beat for only three active chunks.
|
| 218 |
+
|
| 219 |
+
\subsection{End-to-end training on T4 (100 steps)}
|
| 220 |
+
Six layers, 8 heads, $B{=}8$, $T{=}256$, chunk\_size${=}64$, $10\%$ active,
|
| 221 |
+
seed${=}42$, AdamW lr$=$5e-4, full\_$G_X$. $d{=}2048$ did not fit 16GB T4.
|
| 222 |
+
|
| 223 |
+
\begin{table}[t]
|
| 224 |
+
\centering
|
| 225 |
+
\caption{T4 E2E (100 steps); ``vs Dense'' is dense/ms\_mode.}
|
| 226 |
+
\label{tab:t4-e2e}
|
| 227 |
+
\begin{tabular}{r l r r r}
|
| 228 |
+
\toprule
|
| 229 |
+
$d_{\text{model}}$ & Mode & ms/step & vs.\ Dense & Val.\ loss \\
|
| 230 |
+
\midrule
|
| 231 |
+
512 & dense & 184.6 & 1.00$\times$ & 5.6954 \\
|
| 232 |
+
512 & pyloop & 179.0 & 1.03$\times$ & 5.8683 \\
|
| 233 |
+
512 & triton & 196.0 & 0.94$\times$ & 5.8683 \\
|
| 234 |
+
\midrule
|
| 235 |
+
1024 & dense & 451.5 & 1.00$\times$ & 5.5300 \\
|
| 236 |
+
1024 & pyloop & 435.6 & 1.04$\times$ & 5.4803 \\
|
| 237 |
+
1024 & triton & 441.0 & 1.02$\times$ & 5.4800 \\
|
| 238 |
+
\bottomrule
|
| 239 |
+
\end{tabular}
|
| 240 |
+
\end{table}
|
| 241 |
+
|
| 242 |
+
Triton E2E at $d{=}512$ is slower than dense here; autotune and short-run
|
| 243 |
+
overhead dominate at small scale in the author's log.
|
| 244 |
+
|
| 245 |
+
\subsection{EMA vs.\ oracle chunk overlap (T4)}
|
| 246 |
+
$d{=}512$, 6 layers, chunk\_size${=}64$, $10\%$ active, 350 steps, seed${=}42$;
|
| 247 |
+
first check step ${=}250$ post-anneal schedule. Jaccard/Recall vs.\ dense-oracle
|
| 248 |
+
top-$k$ (as implemented in experiment).
|
| 249 |
+
|
| 250 |
+
\begin{table}[t]
|
| 251 |
+
\centering
|
| 252 |
+
\caption{Predictor overlap (single seed; multi-seed long runs were pending).}
|
| 253 |
+
\label{tab:ema-overlap}
|
| 254 |
+
\begin{tabular}{r r r}
|
| 255 |
+
\toprule
|
| 256 |
+
Step & Jaccard & Recall \\
|
| 257 |
+
\midrule
|
| 258 |
+
250 & 0.6000 & 0.7500 \\
|
| 259 |
+
275 & 0.6552 & 0.7917 \\
|
| 260 |
+
300 & 0.7778 & 0.8750 \\
|
| 261 |
+
325 & 0.6000 & 0.7500 \\
|
| 262 |
+
\bottomrule
|
| 263 |
+
\end{tabular}
|
| 264 |
+
\end{table}
|
| 265 |
+
|
| 266 |
+
\subsection{Chunk size vs.\ step time (T4, PyLoop)}
|
| 267 |
+
$d{=}512$, 6 layers, $10\%$ active, seed${=}42$, 50 training steps (warmup;
|
| 268 |
+
loss not converged---timing only).
|
| 269 |
+
|
| 270 |
+
\begin{table}[t]
|
| 271 |
+
\centering
|
| 272 |
+
\caption{ms/step vs.\ chunk size (PyLoop backend).}
|
| 273 |
+
\label{tab:chunk-size}
|
| 274 |
+
\begin{tabular}{r r}
|
| 275 |
+
\toprule
|
| 276 |
+
Chunk size & ms/step \\
|
| 277 |
+
\midrule
|
| 278 |
+
16 & 601.4 \\
|
| 279 |
+
32 & 453.0 \\
|
| 280 |
+
64 & 321.5 \\
|
| 281 |
+
128 & 251.3 \\
|
| 282 |
+
256 & 219.8 \\
|
| 283 |
+
\bottomrule
|
| 284 |
+
\end{tabular}
|
| 285 |
+
\end{table}
|
| 286 |
+
|
| 287 |
+
Larger chunks $\Rightarrow$ fewer Python iterations per layer in this backend.
|
| 288 |
+
|
| 289 |
+
\subsection{Pending experiments (snapshot)}
|
| 290 |
+
At draft time, additional A10G jobs were in flight, e.g.\ internal IDs
|
| 291 |
+
\texttt{69f38371d70108f37ace1cae} (multi-baseline 2000-step suite),
|
| 292 |
+
\texttt{69f395b3d70108f37ace1cee} ($d$ scaling), and
|
| 293 |
+
\texttt{69f3af45d2c8bd8662bd419d} (E2E Triton including $d{=}2048$). Treat these
|
| 294 |
+
only as lab run pointers.
|
| 295 |
+
|
| 296 |
+
\section{Conclusion}
|
| 297 |
+
Chunked EMA sparsity is not uniformly faster: \textbf{MPS} shows a crossover in
|
| 298 |
+
$d_{\text{model}}$ between 512 and 2048 for full training;
|
| 299 |
+
\textbf{T4} microbenchmarks monotonically favor sparse full-$G_X$ totals from
|
| 300 |
+
$d{\approx}384$ upward to 1.35$\times$ at $d{=}2048$ in Table~\ref{tab:t4-ffn-micro},
|
| 301 |
+
while \textbf{T4 E2E} at 100 steps shows small PyLoop wins and Triton not yet
|
| 302 |
+
winning at $d{=}512$. Triton shows large factors when both $dW$ and $dX$ are
|
| 303 |
+
sparse in the isolated harness, subject to training-quality tradeoffs not fully
|
| 304 |
+
tabulated here. Future work: complete multi-seed tables and fused-kernel E2E at
|
| 305 |
+
large $d$.
|
| 306 |
+
|
| 307 |
+
\end{document}
|
sparse_transformer_v18_fast_knn.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Transformer v18: Fast Chunked Sparse Backward + KNN Sensor Scheduler.
|
| 3 |
+
|
| 4 |
+
This plugs the v16/v17 KNN sensor scheduler into the real chunked sparse backward path.
|
| 5 |
+
It compares dense, EMA-topk sparse, KNN sparse, and random sparse in full_dX and sparse_dX modes.
|
| 6 |
+
|
| 7 |
+
Run:
|
| 8 |
+
python3 sparse_transformer_v18_fast_knn.py --device mps --benchmark_sync
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import math
|
| 15 |
+
import random
|
| 16 |
+
import time
|
| 17 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
torch.set_num_threads(1)
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
Scheduler = Literal["dense", "ema_topk", "knn_scheduler", "random"]
|
| 26 |
+
BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def sync_device(device: str) -> None:
    """Block until queued kernels on *device* finish; no-op for CPU."""
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()
        return
    if device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def set_seed(seed: int) -> None:
    """Seed both torch's and Python's RNGs for reproducible runs."""
    torch.manual_seed(seed)
    random.seed(seed)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Create a CPU ``torch.Generator`` pre-seeded with *seed*."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic fake corpus of short, period-terminated sentences."""
    rng = random.Random(seed)
    vocab = [
        "ada", "turing", "grace", "lovelace", "gradients", "tokens", "circuits",
        "features", "boldly", "strangely", "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sentences = []
    for _ in range(n_sentences):
        # randint is drawn before choices, matching the original RNG call order.
        n_words = rng.randint(4, 10)
        sentences.append(" ".join(rng.choices(vocab, k=n_words)) + ".")
    return "\n".join(sentences)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CharCorpus:
    """Character-level corpus with a 90/10 train/validation split."""

    def __init__(self, text: str, block_size: int, device: str):
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.vocab_size = len(vocab)
        self.block_size = block_size
        self.device = device
        encoded = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(self, split: str, batch_size: int, generator: Optional[torch.Generator] = None):
        """Sample (input, next-char target) windows from the chosen split."""
        source = self.train_data if split == "train" else self.val_data
        starts = torch.randint(len(source) - self.block_size - 1, (batch_size,), generator=generator)
        xs = torch.stack([source[i : i + self.block_size] for i in starts])
        ys = torch.stack([source[i + 1 : i + self.block_size + 1] for i in starts])
        return xs.to(self.device), ys.to(self.device)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class ChunkedMaskedLinear(torch.autograd.Function):
    """Linear op with a dense forward and a chunk-sparse backward.

    Forward is a plain ``F.linear``. Backward fills dW/db only for rows of
    ``weight`` belonging to ``active_chunks`` (contiguous blocks of
    ``chunk_size`` output features); inactive chunks keep zero gradient.
    When ``sparse_dx`` is set, dX is also restricted to contributions from
    active chunks instead of the full ``grad_y @ W``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, active_chunks, chunk_size: int, sparse_dx: bool):
        # Bias values are not needed for backward, only whether a bias exists.
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        ctx.chunk_size = int(chunk_size)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y):
        x, weight, active_chunks = ctx.saved_tensors
        chunk_size = ctx.chunk_size
        # Collapse all leading (batch/sequence) dims into a single row axis.
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])
        grad_w = torch.zeros_like(weight)
        grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) if ctx.has_bias else None
        # Full dX is one dense matmul; sparse dX is accumulated chunk by chunk below.
        grad_x_flat = torch.zeros_like(x_flat) if ctx.sparse_dx else gy_flat @ weight

        for c_idx in active_chunks.tolist():
            start = int(c_idx) * chunk_size
            end = start + chunk_size
            gy_slice = gy_flat[:, start:end]
            w_slice = weight[start:end, :]
            # dW rows for this chunk: [chunk_size, M] @ [M, d_in].
            grad_w[start:end, :] = gy_slice.transpose(0, 1) @ x_flat
            if grad_b is not None:
                grad_b[start:end] = gy_slice.sum(dim=0)
            if ctx.sparse_dx:
                grad_x_flat += gy_slice @ w_slice
        # None gradients for the non-tensor inputs (active_chunks, chunk_size, sparse_dx).
        return grad_x_flat.reshape(x.shape), grad_w, grad_b, None, None, None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SparseLinear(nn.Linear):
    """``nn.Linear`` with an optional chunk-sparse backward path.

    The sparse path is off until a scheduler sets ``sparse_enabled`` and
    installs an ``active_chunks`` index tensor; until then this behaves
    exactly like a plain linear layer.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_chunks: Optional[torch.Tensor] = None
        self.chunk_size = 64

    def forward(self, x):
        use_sparse = self.sparse_enabled and self.active_chunks is not None
        if use_sparse:
            return ChunkedMaskedLinear.apply(
                x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx
            )
        return F.linear(x, self.weight, self.bias)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention; QKV and output projections are
    SparseLinear so the scheduler can sparsify their backward pass."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular mask broadcastable over (batch, head) dims.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class FeedForward(nn.Module):
    """GPT-style MLP: expand 4x, GELU, project back, then dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x):
        x = self.attn(self.ln1(x)) + x
        return self.mlp(self.ln2(x)) + x
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class MiniGPT(nn.Module):
    """Small decoder-only GPT: token + learned positional embeddings,
    a stack of Blocks, final LayerNorm, and a dense LM head."""

    def __init__(self, vocab_size: int, block_size: int, n_layer: int, n_head: int, n_embd: int, dropout: float):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_sparse_linears(model):
    """Collect every SparseLinear submodule of *model* in traversal order."""
    found = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class FastChunkScheduler:
    """Decides which output-feature chunks (across all SparseLinear layers)
    are active each training step, and maintains the statistics used to decide.

    Global chunk ids are assigned contiguously over the model's SparseLinear
    modules; each module also keeps a local 0..n-1 id range so the global
    active mask can be translated into per-module masks.

    Scoring state:
      - predicted_mass: EMA of per-chunk gradient norms (updated only for
        chunks that were active, i.e. observed, that step)
      - similarity: chunk-vs-chunk correlation of gradient-mass traces
        collected during warmup (restricted to within-module pairs)
      - sensor_scores: KNN estimate of each chunk's current gradient mass
        (only populated for the "knn_scheduler" mode)
    """

    def __init__(self, model, scheduler: Scheduler, target_fraction: float, chunk_size: int, device: str,
                 mass_beta: float = 0.95, similarity_history: int = 128, min_similarity_history: int = 8,
                 knn_k: int = 3):
        self.scheduler = scheduler
        self.target_fraction = target_fraction
        self.chunk_size = chunk_size
        self.device = device
        self.mass_beta = mass_beta  # EMA decay for predicted_mass
        self.similarity_history = similarity_history  # max warmup snapshots kept
        self.min_similarity_history = min_similarity_history  # snapshots needed before building similarity
        self.knn_k = knn_k  # neighbours used for inactive-chunk estimates
        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids = {}  # module -> global chunk id tensor
        self.module_to_local_ids = {}  # module -> local chunk ids (0..n_chunks-1)
        offset = 0
        for m in self.linears:
            m.chunk_size = chunk_size
            # Chunking requires out_features to divide evenly.
            assert m.out_features % chunk_size == 0
            n_chunks = m.out_features // chunk_size
            ids = torch.arange(offset, offset + n_chunks, device=device)
            self.module_to_chunk_ids[m] = ids
            self.module_to_local_ids[m] = torch.arange(n_chunks, device=device)
            offset += n_chunks
        self.n_chunks = offset
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.mass_history = []  # per-step gradient-mass snapshots (warmup only)
        self.similarity = None  # [n_chunks, n_chunks] correlation, or None until enough history
        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)
        self.sensor_scores = torch.zeros(self.n_chunks, device=device)

    def current_fraction(self, step: int, warmup_steps: int, anneal_steps: int) -> float:
        """Active-fraction schedule: 1.0 during warmup, cosine-anneal down to
        target_fraction over anneal_steps, then hold at target_fraction."""
        if self.scheduler == "dense" or step < warmup_steps:
            return 1.0
        if anneal_steps > 0 and step < warmup_steps + anneal_steps:
            progress = (step - warmup_steps) / anneal_steps
            cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
            return self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
        return self.target_fraction

    def choose_active(self, step: int, warmup_steps: int, anneal_steps: int):
        """Select this step's active chunks and push the mask into the layers.

        Returns the global boolean active mask (also kept on self).
        """
        frac = self.current_fraction(step, warmup_steps, anneal_steps)
        if frac >= 0.999 or self.scheduler == "dense":
            # Effectively dense: activate everything.
            self.active_chunks.fill_(True)
            self.install_local_masks()
            return self.active_chunks
        k = max(1, int(frac * self.n_chunks))
        self.active_chunks.fill_(False)
        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        elif self.scheduler == "ema_topk":
            # Tiny random jitter breaks ties between equal scores.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices
        elif self.scheduler == "knn_scheduler":
            # Prefer KNN sensor scores once they are non-trivial; fall back to EMA.
            base = self.sensor_scores if torch.count_nonzero(self.sensor_scores).item() else self.predicted_mass
            scores = base + 1e-9 * torch.rand_like(base)
            idx = torch.topk(scores, k=k).indices
        else:
            raise ValueError(f"Unknown scheduler: {self.scheduler}")
        self.active_chunks[idx] = True
        self.install_local_masks()
        return self.active_chunks

    def install_local_masks(self):
        """Translate the global active mask into per-module local chunk ids."""
        for m, global_ids in self.module_to_chunk_ids.items():
            local = self.module_to_local_ids[m]
            m.active_chunks = local[self.active_chunks[global_ids]]

    @torch.no_grad()
    def update_from_active_gradients(self, step: int, warmup_steps: int):
        """Measure per-chunk gradient mass after backward and refresh the
        EMA, warmup similarity history, and sensor scores."""
        current_mass = torch.zeros_like(self.predicted_mass)
        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None:
                continue
            # L2 norm of the weight-gradient rows in each chunk (plus bias rows).
            w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)
            current_mass[ids] = torch.sqrt(w_sq + 1e-30)
        observed = self.active_chunks
        # First observation seeds the EMA directly; later ones blend with beta.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen
        self.predicted_mass[never_seen] = current_mass[never_seen]
        self.predicted_mass[already_seen] = self.mass_beta * self.predicted_mass[already_seen] + (1 - self.mass_beta) * current_mass[already_seen]
        if step < warmup_steps:
            # During warmup everything is active, so snapshots are complete.
            self.mass_history.append(current_mass.detach().clone())
            if len(self.mass_history) > self.similarity_history:
                self.mass_history = self.mass_history[-self.similarity_history:]
            if len(self.mass_history) >= self.min_similarity_history:
                self.similarity = self.build_similarity()
        self.sensor_scores = self.knn_scores(self.active_chunks, current_mass) if self.scheduler == "knn_scheduler" else self.predicted_mass.clone()

    def build_similarity(self):
        """Correlation of per-chunk mass traces over the warmup history,
        clamped to non-negative and restricted to within-module pairs."""
        H = torch.stack(self.mass_history, dim=0)
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)
        S = (H.T @ H) / max(1, H.shape[0] - 1)
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)
        # Only allow similarity between chunks of the same SparseLinear module.
        allowed = torch.zeros_like(S, dtype=torch.bool)
        for _, ids in self.module_to_chunk_ids.items():
            allowed[ids[:, None], ids[None, :]] = True
        return torch.where(allowed, S, torch.zeros_like(S))

    def knn_scores(self, active_mask, current_mass):
        """Per-chunk score estimate: active chunks use their fresh measured
        mass; inactive chunks get a similarity-weighted average of their
        top-k most-similar active neighbours (EMA fallback otherwise)."""
        if self.similarity is None:
            return self.predicted_mass.clone()
        scores = self.predicted_mass.clone()
        scores[active_mask] = current_mass[active_mask]
        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()
        if active_idx.numel() == 0:
            return scores
        S = self.similarity
        for i in inactive_idx.tolist():
            weights = S[i, active_idx]
            if weights.sum() <= 1e-12:
                # No informative neighbours; keep the EMA value.
                continue
            kk = min(self.knn_k, weights.numel())
            top = torch.topk(weights, k=kk)
            w = top.values
            aidx = active_idx[top.indices]
            scores[i] = (w * current_mass[aidx]).sum() / (w.sum() + 1e-12)
        return scores
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class ChunkedAdam:
    """Minimal Adam-style optimizer (beta1=0.9, beta2=0.999, eps=1e-8, no
    bias correction). For SparseLinear parameters that currently have an
    active-chunk mask installed, moments and weights are updated only inside
    active chunks; everything else gets the usual dense update.
    """

    def __init__(self, model, lr=3e-4, chunk_size=64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        self.state = {}  # param -> {"m": first moment, "v": second moment}
        self.param_to_sparse_module = {}  # param -> owning SparseLinear
        for m in get_sparse_linears(model):
            if m.weight is not None:
                self.param_to_sparse_module[m.weight] = m
            if m.bias is not None:
                self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        # Drop gradients entirely (set to None) instead of zero-filling.
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                # Lazily allocate moment buffers on first update.
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
            exp_avg = self.state[p]["m"]
            exp_avg_sq = self.state[p]["v"]
            sparse_module = self.param_to_sparse_module.get(p)
            active_chunks = getattr(sparse_module, "active_chunks", None) if sparse_module else None
            if active_chunks is None:
                # Dense update: non-sparse params, or sparse layer with no mask.
                exp_avg.mul_(0.9).add_(p.grad, alpha=0.1)
                exp_avg_sq.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                p.sub_(exp_avg / (torch.sqrt(exp_avg_sq) + 1e-8), alpha=self.lr)
            else:
                # Update only rows of the active chunks; slicing dim 0 works
                # for both 2-D weights and 1-D biases.
                for local_c in active_chunks.tolist():
                    start = int(local_c) * self.chunk_size
                    end = start + self.chunk_size
                    p_chunk = p[start:end]
                    g_chunk = p.grad[start:end]
                    m_chunk = exp_avg[start:end]
                    v_chunk = exp_avg_sq[start:end]
                    m_chunk.mul_(0.9).add_(g_chunk, alpha=0.1)
                    v_chunk.mul_(0.999).addcmul_(g_chunk, g_chunk, value=0.001)
                    p_chunk.sub_(m_chunk / (torch.sqrt(v_chunk) + 1e-8), alpha=self.lr)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def evaluate(model, corpus, batch_size, seed):
    """Return val loss on one deterministic batch; restores train mode."""
    model.eval()
    with torch.no_grad():
        xb, yb = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
        _, loss = model(xb, yb)
    model.train()
    return float(loss.item())
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def run_one(label, scheduler, mode, args):
    """Train one fresh MiniGPT with the given scheduler / backward mode.

    Returns {"val": final validation loss, "ms": mean ms per measured step}.
    The timer is restarted once warmup + anneal finishes, so when steps
    extend past that point "ms" reflects the steady-state sparse regime.
    """
    set_seed(42)
    corpus = CharCorpus(make_synthetic_corpus(), args.block_size, args.device)
    model = MiniGPT(corpus.vocab_size, args.block_size, args.n_layer, args.n_head, args.n_embd, 0.0).to(args.device)
    sched = FastChunkScheduler(model, scheduler, args.active_fraction, args.chunk_size, args.device)
    opt = ChunkedAdam(model, lr=args.lr, chunk_size=args.chunk_size)
    measured_steps = args.steps
    if args.benchmark_sync:
        sync_device(args.device)
    t0 = time.perf_counter()
    for step in range(args.steps):
        if step == args.warmup_steps + args.anneal_steps:
            # Steady state reached: restart the clock for a fair ms/step.
            if args.benchmark_sync:
                sync_device(args.device)
            t0 = time.perf_counter()
            measured_steps = args.steps - step
        if scheduler == "dense" or mode == "dense_baseline":
            # Baseline path: disable sparsity everywhere.
            for m in get_sparse_linears(model):
                m.sparse_enabled = False
                m.active_chunks = None
        else:
            # Pick this step's active chunks and install per-module masks.
            sched.choose_active(step, args.warmup_steps, args.anneal_steps)
            for m in get_sparse_linears(model):
                m.sparse_enabled = True
                m.sparse_dx = mode == "sparse_dW_sparse_dX"
        # Per-step seeded generator keeps the data stream identical across runs.
        x, y = corpus.get_batch("train", args.batch_size, generator=make_cpu_generator(step))
        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()
        if scheduler != "dense" and mode != "dense_baseline":
            # Feed observed gradient masses back into the scheduler before
            # opt.step() (the next zero_grad clears the grads it reads).
            sched.update_from_active_gradients(step, args.warmup_steps)
        opt.step()
    if args.benchmark_sync:
        sync_device(args.device)
    elapsed = time.perf_counter() - t0
    return {"val": evaluate(model, corpus, args.batch_size, 12345), "ms": 1000 * elapsed / max(1, measured_steps)}
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def main():
    """CLI entry: benchmark dense vs. sparse-backward variants, print a table."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--steps", type=int, default=500)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--block_size", type=int, default=256)
    ap.add_argument("--n_layer", type=int, default=4)
    ap.add_argument("--n_head", type=int, default=16)
    ap.add_argument("--n_embd", type=int, default=1024)
    ap.add_argument("--chunk_size", type=int, default=64)
    ap.add_argument("--active_fraction", type=float, default=0.10)
    ap.add_argument("--warmup_steps", type=int, default=25)
    ap.add_argument("--anneal_steps", type=int, default=150)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--device", type=str, default="mps")
    ap.add_argument("--benchmark_sync", action="store_true")
    args = ap.parse_args()

    # (label, scheduler, backward mode) combinations; dense baseline runs first
    # so later rows can report a speedup relative to it.
    configurations = [
        ("dense", "dense", "dense_baseline"),
        ("ema_full_dX", "ema_topk", "sparse_dW_full_dX"),
        ("knn_full_dX", "knn_scheduler", "sparse_dW_full_dX"),
        ("random_full_dX", "random", "sparse_dW_full_dX"),
        ("ema_sparse_dX", "ema_topk", "sparse_dW_sparse_dX"),
        ("knn_sparse_dX", "knn_scheduler", "sparse_dW_sparse_dX"),
        ("random_sparse_dX", "random", "sparse_dW_sparse_dX"),
    ]

    print("\nFast chunked sparse backward with KNN scheduler")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} layers={args.n_layer}")
    print(f"batch={args.batch_size} block={args.block_size} chunk={args.chunk_size}")
    print(f"active={args.active_fraction} warmup={args.warmup_steps} anneal={args.anneal_steps}\n")
    print(f"{'run':>18s} | {'val':>8s} | {'ms/step':>8s} | {'speedup':>8s}")
    print("-" * 58)

    baseline_ms = None
    for label, scheduler, mode in configurations:
        stats = run_one(label, scheduler, mode, args)
        if label == "dense":
            baseline_ms = stats["ms"]
        speedup = baseline_ms / stats["ms"] if baseline_ms else float("nan")
        print(f"{label:>18s} | {stats['val']:8.4f} | {stats['ms']:8.2f} | {speedup:8.3f}")
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# Script entry point: run the full benchmark suite when executed directly.
if __name__ == "__main__":
    main()
|
sparse_transformer_v18_fast_knn_triton.py
ADDED
|
@@ -0,0 +1,1044 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Sparse Transformer v19: Triton-backed v18 KNN Scheduler.
|
| 4 |
+
|
| 5 |
+
This is the CUDA/Triton version of the fast chunked sparse-backward loop.
|
| 6 |
+
|
| 7 |
+
It combines:
|
| 8 |
+
- chunked sparse Linear backward
|
| 9 |
+
- Triton fused active-chunk dW + dBias
|
| 10 |
+
- optional Triton sparse dX
|
| 11 |
+
- EMA / KNN sensor scheduler / random support
|
| 12 |
+
- chunked sparse Adam update
|
| 13 |
+
|
| 14 |
+
Core modes:
|
| 15 |
+
dense
|
| 16 |
+
ema_full_dX
|
| 17 |
+
knn_full_dX
|
| 18 |
+
random_full_dX
|
| 19 |
+
ema_sparse_dX
|
| 20 |
+
knn_sparse_dX
|
| 21 |
+
random_sparse_dX
|
| 22 |
+
|
| 23 |
+
Safe mode:
|
| 24 |
+
knn_full_dX
|
| 25 |
+
Forward dense, dW/db sparse, dX full.
|
| 26 |
+
|
| 27 |
+
Aggressive mode:
|
| 28 |
+
knn_sparse_dX
|
| 29 |
+
Forward dense, dW/db sparse, dX sparse through active chunks.
|
| 30 |
+
|
| 31 |
+
Run:
|
| 32 |
+
python3 sparse_transformer_v19_triton_knn.py --device cuda --benchmark_sync
|
| 33 |
+
|
| 34 |
+
Useful:
|
| 35 |
+
python3 sparse_transformer_v19_triton_knn.py --device cuda --steps 500 --n_embd 1024 --benchmark_sync
|
| 36 |
+
python3 sparse_transformer_v19_triton_knn.py --device cuda --steps 500 --n_embd 2048 --benchmark_sync
|
| 37 |
+
|
| 38 |
+
Notes:
|
| 39 |
+
- This script needs CUDA + Triton.
|
| 40 |
+
- No autotune. Fixed configs reduce compile noise and keep comparisons stable.
|
| 41 |
+
- dW+dBias is fused.
|
| 42 |
+
- Uses block_ptr/tiled loads. On T4 this is not Hopper TMA; do not call it TMA.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
from __future__ import annotations
|
| 46 |
+
|
| 47 |
+
import argparse
|
| 48 |
+
import math
|
| 49 |
+
import random
|
| 50 |
+
import time
|
| 51 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 52 |
+
|
| 53 |
+
import torch
|
| 54 |
+
|
| 55 |
+
torch.set_num_threads(1)
|
| 56 |
+
import torch.nn as nn
|
| 57 |
+
import torch.nn.functional as F
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
import triton
|
| 61 |
+
import triton.language as tl
|
| 62 |
+
TRITON_AVAILABLE = True
|
| 63 |
+
except Exception:
|
| 64 |
+
triton = None
|
| 65 |
+
tl = None
|
| 66 |
+
TRITON_AVAILABLE = False
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
Scheduler = Literal["dense", "ema_topk", "knn_scheduler", "random"]
|
| 70 |
+
BackwardMode = Literal["dense_baseline", "sparse_dW_full_dX", "sparse_dW_sparse_dX"]
|
| 71 |
+
KernelBackend = Literal["triton", "torch"]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ================================================================
|
| 75 |
+
# Utilities
|
| 76 |
+
# ================================================================
|
| 77 |
+
|
| 78 |
+
def sync_device(device: str) -> None:
    """Wait for outstanding work on the accelerator; CPU needs no sync."""
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()
        return
    if device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def set_seed(seed: int) -> None:
    """Seed Python, torch CPU, and (when available) every CUDA device RNG."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a freshly seeded CPU torch.Generator for reproducible sampling."""
    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed)
    return gen
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ================================================================
|
| 99 |
+
# Data
|
| 100 |
+
# ================================================================
|
| 101 |
+
|
| 102 |
+
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Deterministic toy corpus: newline-separated sentences of 4-10 words
    drawn with replacement from a small fixed vocabulary."""
    rng = random.Random(seed)
    words = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    return "\n".join(
        " ".join(rng.choices(words, k=rng.randint(4, 10))) + "."
        for _ in range(n_sentences)
    )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class CharCorpus:
    """Character-level corpus with a deterministic 90/10 train/val split."""

    def __init__(self, text: str, block_size: int, device: str):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}  # char -> token id
        self.itos = {i: ch for ch, i in self.stoi.items()}  # token id -> char
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.device = device

        data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        # First 90% of the stream trains; the remainder is held out for val.
        self.train_data = data[: int(0.9 * len(data))]
        self.val_data = data[int(0.9 * len(data)) :]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample (inputs, targets) windows; targets are inputs shifted by one.

        A seeded CPU generator makes batch sampling reproducible across runs.
        """
        data = self.train_data if split == "train" else self.val_data
        ix = torch.randint(len(data) - self.block_size - 1, (batch_size,), generator=generator)
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        return x.to(self.device), y.to(self.device)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ================================================================
|
| 142 |
+
# Triton sparse Linear backward kernels
|
| 143 |
+
# ================================================================
|
| 144 |
+
|
| 145 |
+
if TRITON_AVAILABLE:
    @triton.jit
    def _triton_sparse_bwd_dW_db_kernel(
        X_ptr, dY_ptr, dW_ptr, dB_ptr, chunk_ids_ptr,
        M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr, num_active: tl.constexpr,
        stride_xm: tl.constexpr, stride_xk: tl.constexpr,
        stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
        stride_dwn: tl.constexpr, stride_dwk: tl.constexpr,
        HAS_BIAS: tl.constexpr,
        CS: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr,
    ):
        """
        One program computes one [CS, BK] dW tile for one active chunk.
        Bias for that chunk is fused into k_block_id == 0.
        Grid: (num_active, ceil(d_in / BK))
        """
        # Axis 0 enumerates the active chunks, axis 1 the BK-wide d_in slices.
        chunk_linear_id = tl.program_id(0)
        k_block_id = tl.program_id(1)

        # Translate the dense program id into a global chunk id and its
        # first output row.
        chunk_idx = tl.load(chunk_ids_ptr + chunk_linear_id)
        chunk_start = chunk_idx * CS
        k_offset = k_block_id * BK

        # dY is addressed transposed (shape (d_out, M), strides swapped) so
        # that tl.dot(dy_t, x) directly yields a [CS, BK] dW tile.
        dy_block_ptr = tl.make_block_ptr(
            base=dY_ptr,
            shape=(d_out, M),
            strides=(stride_dyn, stride_dym),
            offsets=(chunk_start, 0),
            block_shape=(CS, BM),
            order=(1, 0),
        )

        x_block_ptr = tl.make_block_ptr(
            base=X_ptr,
            shape=(M, d_in),
            strides=(stride_xm, stride_xk),
            offsets=(0, k_offset),
            block_shape=(BM, BK),
            order=(1, 0),
        )

        acc_dw = tl.zeros((CS, BK), dtype=tl.float32)
        # Only the k_block_id == 0 program of each chunk writes dB, so the
        # bias is stored exactly once per active chunk.
        compute_bias = HAS_BIAS and (k_block_id == 0)
        acc_db = tl.zeros((CS,), dtype=tl.float32)

        # March over the M (flattened token) dimension in BM-sized steps;
        # boundary_check zero-fills out-of-range loads, so tails are safe.
        for _ in range(0, M, BM):
            dy_t = tl.load(dy_block_ptr, boundary_check=(0, 1))  # [CS, BM]
            x = tl.load(x_block_ptr, boundary_check=(0, 1))      # [BM, BK]

            acc_dw = tl.dot(dy_t, x, acc=acc_dw)

            if compute_bias:
                # dB is the sum of dY over tokens.
                acc_db += tl.sum(dy_t, axis=1)

            dy_block_ptr = tl.advance(dy_block_ptr, (0, BM))
            x_block_ptr = tl.advance(x_block_ptr, (BM, 0))

        dw_block_ptr = tl.make_block_ptr(
            base=dW_ptr,
            shape=(d_out, d_in),
            strides=(stride_dwn, stride_dwk),
            offsets=(chunk_start, k_offset),
            block_shape=(CS, BK),
            order=(1, 0),
        )
        tl.store(dw_block_ptr, acc_dw.to(dW_ptr.dtype.element_ty), boundary_check=(0, 1))

        if compute_bias:
            rn = chunk_start + tl.arange(0, CS)
            tl.store(dB_ptr + rn, acc_db.to(dB_ptr.dtype.element_ty), mask=rn < d_out)


    @triton.jit
    def _triton_sparse_bwd_dX_kernel(
        dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
        M: tl.constexpr, d_in: tl.constexpr, d_out: tl.constexpr, num_active: tl.constexpr,
        stride_dym: tl.constexpr, stride_dyn: tl.constexpr,
        stride_wn: tl.constexpr, stride_wk: tl.constexpr,
        stride_dxm: tl.constexpr, stride_dxk: tl.constexpr,
        CS: tl.constexpr, BM: tl.constexpr, BK: tl.constexpr,
    ):
        """
        One program computes one [BM, BK] tile of dX by accumulating over active chunks.
        Grid: (ceil(M/BM), ceil(d_in/BK)).
        """
        pid_m = tl.program_id(0)
        pid_k = tl.program_id(1)

        m_offset = pid_m * BM
        k_offset = pid_k * BK

        acc = tl.zeros((BM, BK), dtype=tl.float32)

        # dX = sum over active chunks c of dY[:, c] @ W[c, :]; inactive
        # chunks contribute nothing, so they are simply skipped.
        for i in range(0, num_active):
            chunk_idx = tl.load(chunk_ids_ptr + i)
            chunk_start = chunk_idx * CS

            dy_block_ptr = tl.make_block_ptr(
                base=dY_ptr,
                shape=(M, d_out),
                strides=(stride_dym, stride_dyn),
                offsets=(m_offset, chunk_start),
                block_shape=(BM, CS),
                order=(1, 0),
            )

            w_block_ptr = tl.make_block_ptr(
                base=W_ptr,
                shape=(d_out, d_in),
                strides=(stride_wn, stride_wk),
                offsets=(chunk_start, k_offset),
                block_shape=(CS, BK),
                order=(1, 0),
            )

            dy = tl.load(dy_block_ptr, boundary_check=(0, 1))  # [BM, CS]
            w = tl.load(w_block_ptr, boundary_check=(0, 1))    # [CS, BK]
            acc = tl.dot(dy, w, acc=acc)

        dx_block_ptr = tl.make_block_ptr(
            base=dX_ptr,
            shape=(M, d_in),
            strides=(stride_dxm, stride_dxk),
            offsets=(m_offset, k_offset),
            block_shape=(BM, BK),
            order=(1, 0),
        )
        tl.store(dx_block_ptr, acc.to(dX_ptr.dtype.element_ty), boundary_check=(0, 1))
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def triton_sparse_bwd_dW_db(
    x_flat: torch.Tensor,
    gy_flat: torch.Tensor,
    active_chunks: torch.Tensor,
    chunk_size: int,
    d_out: int,
    has_bias: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Launch the fused dW/dB Triton kernel over the active output chunks.

    Returns (dW, dB); dB is None when has_bias is False. Rows belonging to
    inactive chunks keep their zero initialization.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not available")

    M, d_in = x_flat.shape
    n_active = int(active_chunks.numel())

    dW = torch.zeros((d_out, d_in), device=x_flat.device, dtype=x_flat.dtype)
    dB = None
    if has_bias:
        dB = torch.zeros((d_out,), device=x_flat.device, dtype=x_flat.dtype)

    # Nothing to accumulate: hand back the all-zero gradients.
    if n_active == 0:
        return dW, dB

    ids_i32 = active_chunks.to(torch.int32).contiguous()

    # Fixed tiling; one program per (active chunk, BK-wide d_in slice).
    cs, bk, bm = int(chunk_size), 64, 64
    launch_grid = (n_active, triton.cdiv(d_in, bk))

    _triton_sparse_bwd_dW_db_kernel[launch_grid](
        x_flat, gy_flat, dW, dB if has_bias else dW, ids_i32,
        M, d_in, d_out, n_active,
        x_flat.stride(0), x_flat.stride(1),
        gy_flat.stride(0), gy_flat.stride(1),
        dW.stride(0), dW.stride(1),
        HAS_BIAS=has_bias,
        CS=cs, BK=bk, BM=bm,
        num_warps=4,
    )

    return dW, dB
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def triton_sparse_bwd_dX(
    gy_flat: torch.Tensor,
    weight: torch.Tensor,
    active_chunks: torch.Tensor,
    chunk_size: int,
    M: int,
    d_in: int,
) -> torch.Tensor:
    """Launch the dX Triton kernel; each output tile sums over active chunks only."""
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not available")

    n_active = int(active_chunks.numel())
    d_out = gy_flat.shape[1]
    dX = torch.zeros((M, d_in), device=gy_flat.device, dtype=gy_flat.dtype)

    # No active chunks means dX is identically zero.
    if n_active == 0:
        return dX

    ids_i32 = active_chunks.to(torch.int32).contiguous()

    # Fixed tiling: one program per [BM, BK] tile of dX.
    cs, bm, bk = int(chunk_size), 64, 64
    launch_grid = (triton.cdiv(M, bm), triton.cdiv(d_in, bk))

    _triton_sparse_bwd_dX_kernel[launch_grid](
        gy_flat, weight, dX, ids_i32,
        M, d_in, d_out, n_active,
        gy_flat.stride(0), gy_flat.stride(1),
        weight.stride(0), weight.stride(1),
        dX.stride(0), dX.stride(1),
        CS=cs, BM=bm, BK=bk,
        num_warps=4,
    )

    return dX
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# ================================================================
|
| 358 |
+
# Sparse Linear autograd
|
| 359 |
+
# ================================================================
|
| 360 |
+
|
| 361 |
+
class ChunkedMaskedLinearTorch(torch.autograd.Function):
    """Dense forward; backward restricted to active output chunks (pure torch).

    Forward is an ordinary F.linear. In backward, dW and dB are written only
    for rows belonging to the given active chunks (other rows stay zero).
    dX is either the full dense gradient or, with sparse_dx, rebuilt from the
    active chunks alone.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_chunks: torch.Tensor,
        chunk_size: int,
        sparse_dx: bool,
    ) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        ctx.chunk_size = int(chunk_size)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        cs = ctx.chunk_size

        # Flatten any leading batch dims to a 2-D [tokens, features] view.
        x2d = x.reshape(-1, x.shape[-1])
        gy2d = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight)
        grad_b = None
        if ctx.has_bias:
            grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)

        if ctx.sparse_dx:
            # Accumulated chunk-by-chunk below.
            gx2d = torch.zeros_like(x2d)
        else:
            # Full dense input gradient.
            gx2d = gy2d @ weight

        for chunk in active_chunks.tolist():
            lo = int(chunk) * cs
            hi = lo + cs

            gy_c = gy2d[:, lo:hi]
            grad_w[lo:hi, :] = gy_c.t() @ x2d

            if grad_b is not None:
                grad_b[lo:hi] = gy_c.sum(dim=0)

            if ctx.sparse_dx:
                gx2d += gy_c @ weight[lo:hi, :]

        # One gradient slot per forward argument; the last three are non-tensors.
        return gx2d.reshape(x.shape), grad_w, grad_b, None, None, None
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class ChunkedMaskedLinearTriton(torch.autograd.Function):
    """Dense forward; Triton-backed chunk-sparse backward.

    Same contract as ChunkedMaskedLinearTorch, with the backward pass
    delegated to the Triton launcher functions.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        active_chunks: torch.Tensor,
        chunk_size: int,
        sparse_dx: bool,
    ) -> torch.Tensor:
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = bool(sparse_dx)
        ctx.chunk_size = int(chunk_size)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, weight, active_chunks = ctx.saved_tensors
        cs = ctx.chunk_size

        orig_shape = x.shape
        # The kernels index with explicit strides, so make both operands
        # contiguous 2-D views first.
        x2d = x.reshape(-1, x.shape[-1]).contiguous()
        gy2d = grad_y.reshape(-1, grad_y.shape[-1]).contiguous()

        grad_w, grad_b = triton_sparse_bwd_dW_db(
            x_flat=x2d,
            gy_flat=gy2d,
            active_chunks=active_chunks,
            chunk_size=cs,
            d_out=weight.shape[0],
            has_bias=ctx.has_bias,
        )

        if ctx.sparse_dx:
            gx2d = triton_sparse_bwd_dX(
                gy_flat=gy2d,
                weight=weight.contiguous(),
                active_chunks=active_chunks,
                chunk_size=cs,
                M=x2d.shape[0],
                d_in=x2d.shape[1],
            )
        else:
            # Dense input gradient path.
            gx2d = gy2d @ weight

        return gx2d.reshape(orig_shape), grad_w, grad_b, None, None, None
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class SparseLinear(nn.Linear):
    """nn.Linear whose backward can be restricted to active output chunks.

    With sparse_enabled and a set of active_chunks, forward routes through
    one of the chunk-masked autograd functions; otherwise it is a plain
    F.linear. The backend ("triton" or torch fallback) is switchable.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.sparse_enabled = False
        self.sparse_dx = False
        self.active_chunks: Optional[torch.Tensor] = None
        self.chunk_size = 64
        self.kernel_backend: KernelBackend = "triton"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dense path when sparsity is off or no mask has been installed yet.
        if not self.sparse_enabled or self.active_chunks is None:
            return F.linear(x, self.weight, self.bias)

        if self.kernel_backend == "triton":
            op = ChunkedMaskedLinearTriton
        else:
            op = ChunkedMaskedLinearTorch
        return op.apply(
            x, self.weight, self.bias, self.active_chunks, self.chunk_size, self.sparse_dx
        )
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
# ================================================================
|
| 487 |
+
# MiniGPT
|
| 488 |
+
# ================================================================
|
| 489 |
+
|
| 490 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on SparseLinear projections."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Fused QKV projection and output projection.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular mask broadcastable over (batch, head).
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, T, C] -> [B, n_head, T, head_dim]
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        probs = self.dropout(F.softmax(scores, dim=-1))

        out = (probs @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class FeedForward(nn.Module):
    """Standard 4x-expansion GELU MLP with sparse-capable projections."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.c_fc = SparseLinear(n_embd, 4 * n_embd)
        self.c_proj = SparseLinear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x + self.attn(self.ln1(x))
        return h + self.mlp(self.ln2(h))
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class MiniGPT(nn.Module):
    """Small GPT-style decoder language model with learned positional embeddings."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Return (logits, loss); loss is None when no targets are given."""
        B, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        h = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        h = self.ln_f(self.blocks(h))
        logits = self.lm_head(h)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every SparseLinear submodule of *model*, in module-walk order."""
    found = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
# ================================================================
|
| 587 |
+
# Scheduler
|
| 588 |
+
# ================================================================
|
| 589 |
+
|
| 590 |
+
class FastChunkScheduler:
    """Chooses which output chunks are active each step across all SparseLinears.

    Keeps a global flat chunk index (one id per chunk, concatenated over all
    SparseLinear modules), an EMA of per-chunk gradient mass, a warmup-time
    correlation matrix between chunks, and derived "sensor" scores used by the
    kNN scheduler.
    """

    def __init__(
        self,
        model: nn.Module,
        scheduler: Scheduler,
        target_fraction: float,
        chunk_size: int,
        device: str,
        mass_beta: float = 0.95,
        similarity_history: int = 128,
        min_similarity_history: int = 8,
        knn_k: int = 3,
    ):
        self.model = model
        self.scheduler = scheduler
        self.target_fraction = target_fraction
        self.chunk_size = chunk_size
        self.device = device
        self.mass_beta = mass_beta                      # EMA decay for predicted mass
        self.similarity_history = similarity_history    # max snapshots kept
        self.min_similarity_history = min_similarity_history
        self.knn_k = knn_k                              # neighbors used per inactive chunk

        self.linears = get_sparse_linears(model)
        # Global chunk ids and the module-local ids they correspond to.
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.module_to_local_ids: Dict[nn.Module, torch.Tensor] = {}

        offset = 0
        for m in self.linears:
            m.chunk_size = chunk_size
            n_chunks = m.out_features // chunk_size
            assert m.out_features % chunk_size == 0, (
                f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
            )

            ids = torch.arange(offset, offset + n_chunks, device=device)
            local = torch.arange(n_chunks, device=device)

            self.module_to_chunk_ids[m] = ids
            self.module_to_local_ids[m] = local
            offset += n_chunks

        self.n_chunks = offset
        # EMA estimate of each chunk's gradient mass (0 == never observed).
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.mass_history: List[torch.Tensor] = []
        self.similarity: Optional[torch.Tensor] = None

        self.active_chunks = torch.zeros(self.n_chunks, dtype=torch.bool, device=device)
        self.sensor_scores = torch.zeros(self.n_chunks, device=device)

    def current_fraction(self, step: int, warmup_steps: int, anneal_steps: int) -> float:
        """Active fraction schedule: 1.0 in warmup, cosine anneal, then target."""
        if self.scheduler == "dense":
            return 1.0
        if step < warmup_steps:
            return 1.0
        if anneal_steps > 0 and step < warmup_steps + anneal_steps:
            progress = (step - warmup_steps) / anneal_steps
            cosine_mult = 0.5 * (1.0 + math.cos(math.pi * progress))
            return self.target_fraction + (1.0 - self.target_fraction) * cosine_mult
        return self.target_fraction

    def choose_active(self, step: int, warmup_steps: int, anneal_steps: int) -> torch.Tensor:
        """Select the active chunk set for this step and push it to the modules."""
        frac = self.current_fraction(step, warmup_steps, anneal_steps)

        if frac >= 0.999 or self.scheduler == "dense":
            self.active_chunks.fill_(True)
            self.install_local_masks()
            return self.active_chunks

        k = max(1, int(frac * self.n_chunks))
        self.active_chunks.fill_(False)

        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]

        elif self.scheduler == "ema_topk":
            # Tiny random jitter breaks ties among equal-mass chunks.
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices

        elif self.scheduler == "knn_scheduler":
            base = self.sensor_scores
            # Fall back to the EMA mass until sensor scores exist.
            if torch.count_nonzero(base).item() == 0:
                base = self.predicted_mass
            scores = base + 1e-9 * torch.rand_like(base)
            idx = torch.topk(scores, k=k).indices

        else:
            raise ValueError(f"Unknown scheduler: {self.scheduler}")

        self.active_chunks[idx] = True
        self.install_local_masks()
        return self.active_chunks

    def install_local_masks(self) -> None:
        """Translate the global active mask into per-module local chunk id lists."""
        for m, global_ids in self.module_to_chunk_ids.items():
            local = self.module_to_local_ids[m]
            m.active_chunks = local[self.active_chunks[global_ids]]

    @torch.no_grad()
    def update_from_active_gradients(self, step: int, warmup_steps: int) -> torch.Tensor:
        """Measure per-chunk gradient mass, update EMAs, history and sensor scores."""
        current_mass = torch.zeros_like(self.predicted_mass)

        for m, ids in self.module_to_chunk_ids.items():
            if m.weight.grad is None:
                continue

            # L2 norm of the (weight [+ bias]) gradient per chunk.
            w_sq = m.weight.grad.square().view(len(ids), self.chunk_size, -1).sum(dim=(1, 2))
            if m.bias is not None and m.bias.grad is not None:
                w_sq += m.bias.grad.square().view(len(ids), self.chunk_size).sum(dim=1)

            current_mass[ids] = torch.sqrt(w_sq + 1e-30)

        observed = self.active_chunks
        # First observation seeds the EMA directly; later ones blend in.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen

        self.predicted_mass[never_seen] = current_mass[never_seen]
        self.predicted_mass[already_seen] = (
            self.mass_beta * self.predicted_mass[already_seen]
            + (1.0 - self.mass_beta) * current_mass[already_seen]
        )

        # During warmup (all chunks active) record snapshots for the
        # chunk-to-chunk similarity estimate.
        if step < warmup_steps:
            self.mass_history.append(current_mass.detach().clone())
            if len(self.mass_history) > self.similarity_history:
                self.mass_history = self.mass_history[-self.similarity_history :]
            if len(self.mass_history) >= self.min_similarity_history:
                self.similarity = self.build_similarity()

        if self.scheduler == "knn_scheduler":
            self.sensor_scores = self.knn_scores(self.active_chunks, current_mass)
        else:
            self.sensor_scores = self.predicted_mass.clone()

        return current_mass

    def build_similarity(self) -> torch.Tensor:
        """Correlation of chunk mass over history, clamped to >= 0, intra-module only."""
        H = torch.stack(self.mass_history, dim=0)
        # Standardize each chunk's time series before correlating.
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)

        S = (H.T @ H) / max(1, H.shape[0] - 1)
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)

        # Only allow similarity edges between chunks of the same module.
        allowed = torch.zeros_like(S, dtype=torch.bool)
        for _, ids in self.module_to_chunk_ids.items():
            allowed[ids[:, None], ids[None, :]] = True

        return torch.where(allowed, S, torch.zeros_like(S))

    def knn_scores(self, active_mask: torch.Tensor, current_mass: torch.Tensor) -> torch.Tensor:
        """Score chunks: active ones get their measured mass, inactive ones a
        similarity-weighted average of their k nearest active neighbors."""
        if self.similarity is None:
            return self.predicted_mass.clone()

        scores = self.predicted_mass.clone()
        scores[active_mask] = current_mass[active_mask]

        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()

        if active_idx.numel() == 0:
            return scores

        S = self.similarity
        for i in inactive_idx.tolist():
            weights = S[i, active_idx]
            # No correlated active neighbor: keep the EMA estimate.
            if weights.sum() <= 1e-12:
                continue

            kk = min(self.knn_k, weights.numel())
            top = torch.topk(weights, k=kk)
            w = top.values
            aidx = active_idx[top.indices]
            scores[i] = (w * current_mass[aidx]).sum() / (w.sum() + 1e-12)

        return scores
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
# ================================================================
|
| 770 |
+
# Chunked Adam
|
| 771 |
+
# ================================================================
|
| 772 |
+
|
| 773 |
+
class ChunkedAdam:
    """Adam-style optimizer that only updates active chunks of SparseLinear params.

    NOTE: uses fixed betas (0.9 / 0.999) and applies no bias-correction terms,
    unlike torch.optim.Adam.
    """

    def __init__(self, model: nn.Module, lr: float = 3e-4, chunk_size: int = 64):
        self.model = model
        self.lr = lr
        self.chunk_size = chunk_size
        # Per-parameter state: first ("m") and second ("v") moment EMAs.
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

        # Map each SparseLinear weight/bias back to its module so step() can
        # look up that module's currently active chunks.
        self.param_to_sparse_module: Dict[torch.nn.Parameter, SparseLinear] = {}
        for m in get_sparse_linears(model):
            if m.weight is not None:
                self.param_to_sparse_module[m.weight] = m
            if m.bias is not None:
                self.param_to_sparse_module[m.bias] = m

    def zero_grad(self):
        # Drop grads entirely instead of zeroing in place.
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        for p in self.model.parameters():
            if p.grad is None:
                continue

            # Lazily create the moment buffers on first update.
            if p not in self.state:
                self.state[p] = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}

            exp_avg = self.state[p]["m"]
            exp_avg_sq = self.state[p]["v"]

            sparse_module = self.param_to_sparse_module.get(p)
            active_chunks = getattr(sparse_module, "active_chunks", None) if sparse_module else None

            if active_chunks is None:
                # Dense path: update the whole tensor.
                exp_avg.mul_(0.9).add_(p.grad, alpha=0.1)
                exp_avg_sq.mul_(0.999).addcmul_(p.grad, p.grad, value=0.001)
                p.sub_(exp_avg / (torch.sqrt(exp_avg_sq) + 1e-8), alpha=self.lr)
            else:
                # Sparse path: update only the rows of the active chunks.
                # Chunking is along dim 0 for both weight and bias, so the
                # same slice works for either parameter; slices are views, so
                # the in-place ops below update p and the state directly.
                for local_c in active_chunks.tolist():
                    start = int(local_c) * self.chunk_size
                    end = start + self.chunk_size

                    p_chunk = p[start:end]
                    g_chunk = p.grad[start:end]
                    m_chunk = exp_avg[start:end]
                    v_chunk = exp_avg_sq[start:end]

                    m_chunk.mul_(0.9).add_(g_chunk, alpha=0.1)
                    v_chunk.mul_(0.999).addcmul_(g_chunk, g_chunk, value=0.001)
                    p_chunk.sub_(m_chunk / (torch.sqrt(v_chunk) + 1e-8), alpha=self.lr)
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
# ================================================================
|
| 826 |
+
# Training
|
| 827 |
+
# ================================================================
|
| 828 |
+
|
| 829 |
+
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Loss on one fixed-seed validation batch; restores train mode afterwards."""
    model.eval()
    with torch.no_grad():
        xb, yb = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
        _, loss = model(xb, yb)
    model.train()
    return float(loss.item())
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
def run_one(
    scheduler_name: Scheduler,
    mode: BackwardMode,
    kernel_backend: KernelBackend,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    anneal_steps: int,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train one (scheduler, backward-mode, backend) configuration.

    Returns {"val": final validation loss, "ms": milliseconds per measured
    step}. With benchmark_sync, the device is synchronized around the timed
    region so the ms/step figure is meaningful on async devices.
    """
    set_seed(42)

    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)

    for m in get_sparse_linears(model):
        m.chunk_size = chunk_size
        m.kernel_backend = kernel_backend

    sched = FastChunkScheduler(
        model=model,
        scheduler=scheduler_name,
        target_fraction=active_fraction,
        chunk_size=chunk_size,
        device=device,
    )
    opt = ChunkedAdam(model, lr=3e-4, chunk_size=chunk_size)

    measured_steps = steps

    if benchmark_sync:
        sync_device(device)
    # NOTE(review): t0 is assigned unconditionally so `elapsed` below is
    # defined even without --benchmark_sync.
    t0 = time.perf_counter()

    for step in range(steps):
        # Restart the timer once warmup+anneal ends, so throughput reflects
        # only the steady-state sparse configuration.
        if step == warmup_steps + anneal_steps:
            if benchmark_sync:
                sync_device(device)
            t0 = time.perf_counter()
            measured_steps = steps - step

        if scheduler_name == "dense" or mode == "dense_baseline":
            # Fully dense step: disable the sparse forward path.
            for m in get_sparse_linears(model):
                m.sparse_enabled = False
                m.active_chunks = None
        else:
            # Pick this step's active chunk set and enable sparse backward.
            sched.choose_active(step, warmup_steps=warmup_steps, anneal_steps=anneal_steps)
            for m in get_sparse_linears(model):
                m.sparse_enabled = True
                m.sparse_dx = mode == "sparse_dW_sparse_dX"

        # Per-step seeded batches keep runs comparable across configurations.
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))

        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()

        # Feed the observed (active-chunk) gradient masses back to the scheduler.
        if scheduler_name != "dense" and mode != "dense_baseline":
            sched.update_from_active_gradients(step=step, warmup_steps=warmup_steps)

        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0

    val_loss = evaluate(model, corpus, batch_size, seed=12345)
    return {"val": val_loss, "ms": 1000.0 * elapsed / max(1, measured_steps)}
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
def run_correctness_smoke(
    device: str,
    chunk_size: int = 64,
    dtype: torch.dtype = torch.float32,
) -> None:
    """Compare the Triton sparse backward kernels against a dense torch reference.

    Prints the max absolute error of dW, dB and dX for a couple of layer
    shapes. Skips (with a message) when CUDA or Triton is unavailable.
    """
    if device != "cuda":
        print("Skipping Triton correctness smoke: device is not cuda")
        return
    if not TRITON_AVAILABLE:
        print("Skipping Triton correctness smoke: Triton not available")
        return

    print("\nTriton sparse Linear correctness smoke")
    print("-" * 60)

    torch.manual_seed(123)
    shapes = [(512, 2048), (1024, 4096)]

    for d_in, d_out in shapes:
        M = 512
        n_chunks = d_out // chunk_size
        # ~10% of chunks active, at least one.
        n_active = max(1, int(0.1 * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values

        x = torch.randn(M, d_in, device=device, dtype=dtype)
        w = torch.randn(d_out, d_in, device=device, dtype=dtype)
        gy = torch.randn(M, d_out, device=device, dtype=dtype)

        # Dense reference gradients, restricted to the active chunks.
        ref_dw = torch.zeros_like(w)
        ref_db = torch.zeros(d_out, device=device, dtype=dtype)
        ref_dx = torch.zeros_like(x)

        for c in active.tolist():
            s = c * chunk_size
            e = s + chunk_size
            ref_dw[s:e] = gy[:, s:e].transpose(0, 1) @ x
            ref_db[s:e] = gy[:, s:e].sum(0)
            ref_dx += gy[:, s:e] @ w[s:e]

        tri_dw, tri_db = triton_sparse_bwd_dW_db(x, gy, active, chunk_size, d_out, True)
        tri_dx = triton_sparse_bwd_dX(gy, w, active, chunk_size, M, d_in)

        # Max absolute error per gradient tensor.
        dw_err = float((tri_dw - ref_dw).abs().max().item())
        db_err = float((tri_db - ref_db).abs().max().item())
        dx_err = float((tri_dx - ref_dx).abs().max().item())

        print(f"d_in={d_in:4d} d_out={d_out:5d}: dW={dw_err:.6f} dB={db_err:.6f} dX={dx_err:.6f}")
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
def main() -> None:
    """CLI entry point for the sparse-backward benchmark.

    Parses the run configuration, validates the Triton/CUDA pairing,
    optionally runs the Triton correctness smoke test, then benchmarks
    every scheduler/backward-mode combination and reports the per-step
    time and speedup relative to the dense baseline (which runs first).
    """
    ap = argparse.ArgumentParser()
    # Scalar options share one declaration shape; declare them in bulk.
    for flag, typ, default in (
        ("--steps", int, 500),
        ("--batch_size", int, 16),
        ("--block_size", int, 256),
        ("--n_layer", int, 4),
        ("--n_head", int, 16),
        ("--n_embd", int, 1024),
        ("--chunk_size", int, 64),
        ("--active_fraction", float, 0.10),
        ("--warmup_steps", int, 25),
        ("--anneal_steps", int, 150),
        ("--device", str, "cuda"),
    ):
        ap.add_argument(flag, type=typ, default=default)
    ap.add_argument("--kernel_backend", type=str, default="triton", choices=["triton", "torch"])
    ap.add_argument("--benchmark_sync", action="store_true")
    ap.add_argument("--skip_correctness", action="store_true")
    cli = ap.parse_args()

    # The Triton backend is only usable on CUDA and only if Triton imported.
    if cli.kernel_backend == "triton":
        if cli.device != "cuda":
            raise RuntimeError("--kernel_backend triton requires --device cuda")
        if not TRITON_AVAILABLE:
            raise RuntimeError("Triton is not available")

    if not cli.skip_correctness:
        run_correctness_smoke(device=cli.device, chunk_size=cli.chunk_size)

    # (label, scheduler, backward-mode) triples; "dense" must stay first so
    # the speedup baseline is established before the sparse runs print.
    runs: List[Tuple[str, Scheduler, BackwardMode]] = [
        ("dense", "dense", "dense_baseline"),
        ("ema_full_dX", "ema_topk", "sparse_dW_full_dX"),
        ("knn_full_dX", "knn_scheduler", "sparse_dW_full_dX"),
        ("random_full_dX", "random", "sparse_dW_full_dX"),
        ("ema_sparse_dX", "ema_topk", "sparse_dW_sparse_dX"),
        ("knn_sparse_dX", "knn_scheduler", "sparse_dW_sparse_dX"),
        ("random_sparse_dX", "random", "sparse_dW_sparse_dX"),
    ]

    print("\nTriton-backed fast chunked sparse backward with KNN scheduler")
    print(f"device={cli.device} backend={cli.kernel_backend} triton_available={TRITON_AVAILABLE}")
    print(f"steps={cli.steps} d={cli.n_embd} layers={cli.n_layer}")
    print(f"batch={cli.batch_size} block={cli.block_size} chunk={cli.chunk_size}")
    print(f"active={cli.active_fraction} warmup={cli.warmup_steps} anneal={cli.anneal_steps}\n")
    print(f"{'run':>18s} | {'val':>8s} | {'ms/step':>8s} | {'speedup':>8s}")
    print("-" * 58)

    dense_ms: Optional[float] = None

    for label, scheduler, mode in runs:
        result = run_one(
            scheduler_name=scheduler,
            mode=mode,
            kernel_backend=cli.kernel_backend,  # type: ignore[arg-type]
            device=cli.device,
            steps=cli.steps,
            batch_size=cli.batch_size,
            block_size=cli.block_size,
            n_layer=cli.n_layer,
            n_head=cli.n_head,
            n_embd=cli.n_embd,
            chunk_size=cli.chunk_size,
            active_fraction=cli.active_fraction,
            warmup_steps=cli.warmup_steps,
            anneal_steps=cli.anneal_steps,
            benchmark_sync=cli.benchmark_sync,
        )

        # Capture the dense baseline time the first (and only "dense") time.
        if label == "dense":
            dense_ms = result["ms"]

        if dense_ms is None:
            speedup = float("nan")
        else:
            speedup = dense_ms / result["ms"]

        print(
            f"{label:>18s} | "
            f"{result['val']:8.4f} | "
            f"{result['ms']:8.2f} | "
            f"{speedup:8.3f}"
        )
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|