theapemachine's picture
Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
bc1b8eb
#!/usr/bin/env python3
"""
Triton-fused Chunked Sparse Backward Pass.
Replaces the Python for-loop over active chunks with fused Triton kernels:
1. sparse_bwd_dW: grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X for active c
2. sparse_bwd_dX: grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :] for active c
3. sparse_bwd_dbias: grad_b[c*CS:(c+1)*CS] = grad_Y[:, c*CS:(c+1)*CS].sum(dim=0) for active c
The forward pass is not fused in this file; it is a dense F.linear call.
Benchmark against the Python-loop baseline at various d_model sizes.
"""
import math
import os
import random
import time
import urllib.request
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
try:
import tiktoken
except ImportError:
raise ImportError("pip install tiktoken")
# ═══════════════════════════════════════════════════════════════════
# TRITON KERNELS
# ═══════════════════════════════════════════════════════════════════
# ── Kernel 1: Sparse dW ──────────────────────────────────────────
# For each active chunk c:
# grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X
#
# In terms of shapes:
# grad_Y: (M, d_out), X: (M, d_in), W: (d_out, d_in)
# For chunk c: rows c*CS..(c+1)*CS of W get grad from cols c*CS..(c+1)*CS of grad_Y
#
# Grid: (num_active * ceil(CS/BN), ceil(d_in/BK))
# pid0 encodes (active_chunk_linear_id, N-block within CS)
# pid1 encodes K-block within d_in
@triton.autotune(
    configs=[
        triton.Config({'BN': 32, 'BK': 64, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 64, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 128, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 32, 'BK': 128, 'BM': 64}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 64, 'BM': 64}, num_stages=4, num_warps=4),
    ],
    key=['M', 'd_in', 'CS'],
)
@triton.jit
def _sparse_bwd_dW_kernel(
    X_ptr, dY_ptr, dW_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_xm, stride_xk,
    stride_dym, stride_dyn,
    stride_dwn, stride_dwk,
    CS: tl.constexpr,
    BN: tl.constexpr,
    BK: tl.constexpr,
    BM: tl.constexpr,
):
    """Compute dW tiles for active chunks. Each program writes one [BN, BK] tile.

    Program axis 0 encodes (index into the active-chunk list, N-block inside
    the chunk); program axis 1 is the K-block over ``d_in``. For its tile the
    program accumulates ``dY[:, rows].T @ X[:, cols]`` over all ``M`` rows in
    float32 and stores the result into ``dW[rows, cols]``.

    Rows at or beyond ``d_out`` are masked out, so a final chunk that is only
    partially inside the weight matrix is neither read nor written out of
    bounds.
    """
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)

    N_BLOCKS_PER_CHUNK = tl.cdiv(CS, BN)
    chunk_linear_id = pid0 // N_BLOCKS_PER_CHUNK
    n_block_id = pid0 % N_BLOCKS_PER_CHUNK
    k_block_id = pid1

    if chunk_linear_id >= num_active:
        return

    chunk_idx = tl.load(chunk_ids_ptr + chunk_linear_id)
    chunk_start = chunk_idx * CS

    # Tile ranges
    rn = n_block_id * BN + tl.arange(0, BN)  # rows of dW inside the chunk (= cols of dY)
    rk = k_block_id * BK + tl.arange(0, BK)  # cols of dW (= cols of X)
    n_abs = chunk_start + rn  # absolute row of dW / column of dY

    # Stay inside the chunk AND inside the matrix (guards a partial last chunk).
    n_mask = (rn < CS) & (n_abs < d_out)
    k_mask = rk < d_in

    # Accumulate dY[:, chunk_cols].T @ X[:, k_cols] over M-tiles
    acc = tl.zeros((BN, BK), dtype=tl.float32)
    for m_start in range(0, M, BM):
        rm = m_start + tl.arange(0, BM)
        m_mask = rm < M

        # Load X tile: (BM, BK)
        x = tl.load(
            X_ptr + rm[:, None] * stride_xm + rk[None, :] * stride_xk,
            mask=m_mask[:, None] & k_mask[None, :],
            other=0.0,
        )
        # Load dY tile: (BM, BN)
        dy = tl.load(
            dY_ptr + rm[:, None] * stride_dym + n_abs[None, :] * stride_dyn,
            mask=m_mask[:, None] & n_mask[None, :],
            other=0.0,
        )
        # dY.T @ X -> (BN, BK)
        acc = tl.dot(tl.trans(dy), x, acc=acc)

    # Write to dW: row = chunk_start + rn, col = rk
    # dW layout: (d_out, d_in)
    dw_ptrs = dW_ptr + n_abs[:, None] * stride_dwn + rk[None, :] * stride_dwk
    tl.store(
        dw_ptrs,
        acc.to(dW_ptr.dtype.element_ty),
        mask=n_mask[:, None] & k_mask[None, :],
    )
def sparse_bwd_dW(X, dY, active_chunks, chunk_size, d_out):
    """Fused Triton kernel for sparse dW computation.

    Args:
        X: 2-D input activations of shape (M, d_in).
        dY: 2-D upstream gradient of shape (M, d_out).
        active_chunks: 1-D integer tensor of active chunk indices; chunk ``c``
            covers weight rows ``c * chunk_size`` up to ``(c + 1) * chunk_size``.
        chunk_size: number of weight rows per chunk.
        d_out: number of output features (rows of the weight matrix).

    Returns:
        A (d_out, d_in) tensor on X's device with X's dtype. Rows belonging
        to active chunks hold ``dY[:, rows].T @ X``; all other rows are zero.

    Raises:
        ValueError: if ``X`` / ``dY`` are not 2-D or their shapes disagree
            with each other or with ``d_out``.
    """
    if X.dim() != 2:
        raise ValueError(f"X must be 2-D (M, d_in), got shape {tuple(X.shape)}")
    M, d_in = X.shape
    # The kernel indexes raw pointers, so reject mismatched shapes up front
    # instead of reading or writing out of bounds on the device.
    if dY.dim() != 2 or dY.shape[0] != M or dY.shape[1] != d_out:
        raise ValueError(
            f"dY must have shape ({M}, {d_out}), got {tuple(dY.shape)}"
        )

    num_active = active_chunks.shape[0]
    CS = chunk_size
    dW = torch.zeros(d_out, d_in, device=X.device, dtype=X.dtype)
    if num_active == 0:
        return dW

    # The kernel dereferences the ids on the compute device as int32.
    chunk_ids = active_chunks.to(device=X.device, dtype=torch.int32).contiguous()

    grid = lambda META: (
        num_active * triton.cdiv(CS, META['BN']),
        triton.cdiv(d_in, META['BK']),
    )
    _sparse_bwd_dW_kernel[grid](
        X, dY, dW, chunk_ids,
        M, d_in, d_out, num_active,
        X.stride(0), X.stride(1),
        dY.stride(0), dY.stride(1),
        dW.stride(0), dW.stride(1),
        CS=CS,
    )
    return dW
# ── Kernel 2: Sparse dX ──────────────────────────────────────────
# For each active chunk c:
# grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :]
#
# Grid: (ceil(M/BM), ceil(d_in/BK))
# Each program accumulates contributions from ALL active chunks.
@triton.autotune(
    configs=[
        triton.Config({'BM': 32, 'BK': 64, 'BN': 32}, num_stages=3, num_warps=4),
        triton.Config({'BM': 64, 'BK': 64, 'BN': 32}, num_stages=3, num_warps=4),
        triton.Config({'BM': 64, 'BK': 128, 'BN': 64}, num_stages=3, num_warps=4),
        triton.Config({'BM': 32, 'BK': 128, 'BN': 32}, num_stages=4, num_warps=4),
    ],
    key=['M', 'd_in', 'CS'],
)
@triton.jit
def _sparse_bwd_dX_kernel(
    dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_dym, stride_dyn,
    stride_wn, stride_wk,
    stride_dxm, stride_dxk,
    CS: tl.constexpr,
    BM: tl.constexpr,
    BK: tl.constexpr,
    BN: tl.constexpr,
):
    """Compute dX tiles by summing over active chunks.

    Each program owns one [BM, BK] tile of dX (program axis 0 tiles ``M``,
    axis 1 tiles ``d_in``). It walks every active chunk, accumulating
    ``dY[rows, chunk_cols] @ W[chunk_rows, cols]`` in float32, and stores the
    tile once at the end.

    Output-feature indices at or beyond ``d_out`` are masked out, so a final
    chunk that is only partially inside the matrices is not read out of
    bounds.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    rm = pid_m * BM + tl.arange(0, BM)
    rk = pid_k * BK + tl.arange(0, BK)
    m_mask = rm < M
    k_mask = rk < d_in

    acc = tl.zeros((BM, BK), dtype=tl.float32)

    # Sum over all active chunks
    for i in range(num_active):
        chunk_idx = tl.load(chunk_ids_ptr + i)
        chunk_start = chunk_idx * CS

        # Tile over BN within the chunk
        for n_start in range(0, CS, BN):
            rn = n_start + tl.arange(0, BN)
            n_abs = chunk_start + rn
            # Stay inside the chunk AND inside the matrix.
            n_mask = (rn < CS) & (n_abs < d_out)

            # Load dY tile: (BM, BN)
            dy = tl.load(
                dY_ptr + rm[:, None] * stride_dym + n_abs[None, :] * stride_dyn,
                mask=m_mask[:, None] & n_mask[None, :],
                other=0.0,
            )
            # Load W tile: (BN, BK) -- W[chunk_start + rn, rk]
            w = tl.load(
                W_ptr + n_abs[:, None] * stride_wn + rk[None, :] * stride_wk,
                mask=n_mask[:, None] & k_mask[None, :],
                other=0.0,
            )
            # dY @ W -> (BM, BK)
            acc = tl.dot(dy, w, acc=acc)

    # Write dX
    dx_ptrs = dX_ptr + rm[:, None] * stride_dxm + rk[None, :] * stride_dxk
    tl.store(
        dx_ptrs,
        acc.to(dX_ptr.dtype.element_ty),
        mask=m_mask[:, None] & k_mask[None, :],
    )
def sparse_bwd_dX(dY, W, active_chunks, chunk_size, M, d_in):
    """Fused Triton kernel for sparse dX computation.

    Args:
        dY: 2-D upstream gradient of shape (M, d_out).
        W: 2-D weight matrix of shape (d_out, d_in).
        active_chunks: 1-D integer tensor of active chunk indices; chunk ``c``
            covers output features ``c * chunk_size`` up to
            ``(c + 1) * chunk_size``.
        chunk_size: number of output features per chunk.
        M: number of rows of ``dY`` (and of the result).
        d_in: number of input features (columns of ``W`` and of the result).

    Returns:
        An (M, d_in) tensor on dY's device with dY's dtype holding the sum,
        over active chunks, of ``dY[:, chunk] @ W[chunk, :]``. It is all
        zeros when no chunk is active.

    Raises:
        ValueError: if ``dY`` / ``W`` are not 2-D or their shapes disagree
            with each other or with ``M`` / ``d_in``.
    """
    # The kernel indexes raw pointers, so reject mismatched shapes up front
    # instead of reading or writing out of bounds on the device.
    if dY.dim() != 2 or dY.shape[0] != M:
        raise ValueError(
            f"dY must be 2-D with {M} rows, got shape {tuple(dY.shape)}"
        )
    d_out = dY.shape[1]
    if W.dim() != 2 or W.shape[0] != d_out or W.shape[1] != d_in:
        raise ValueError(
            f"W must have shape ({d_out}, {d_in}), got {tuple(W.shape)}"
        )

    num_active = active_chunks.shape[0]
    CS = chunk_size
    dX = torch.zeros(M, d_in, device=dY.device, dtype=dY.dtype)
    if num_active == 0:
        return dX

    # The kernel dereferences the ids on the compute device as int32.
    chunk_ids = active_chunks.to(device=dY.device, dtype=torch.int32).contiguous()

    grid = lambda META: (
        triton.cdiv(M, META['BM']),
        triton.cdiv(d_in, META['BK']),
    )
    _sparse_bwd_dX_kernel[grid](
        dY, W, dX, chunk_ids,
        M, d_in, d_out, num_active,
        dY.stride(0), dY.stride(1),
        W.stride(0), W.stride(1),
        dX.stride(0), dX.stride(1),
        CS=CS,
    )
    return dX
# ── Kernel 3: Sparse dBias ────────────────────────────────────────
# Simple: bias_grad[c*CS:(c+1)*CS] = dY[:, c*CS:(c+1)*CS].sum(dim=0)
@triton.jit
def _sparse_bwd_dbias_kernel(
    dY_ptr, dB_ptr, chunk_ids_ptr,
    M, d_out, num_active,
    stride_dym, stride_dyn,
    CS: tl.constexpr,
    BM: tl.constexpr,
):
    """Sum one column of dY over all M rows and store it into dB.

    There is one program per (active chunk, column within the chunk). The
    column sum is accumulated in float32 and cast to dB's element type on
    store. Columns at or beyond ``d_out`` are masked out, so a final chunk
    that is only partially inside dY is neither read nor written out of
    bounds.
    """
    pid = tl.program_id(0)  # one per (active_chunk, col_within_chunk)
    chunk_linear = pid // CS
    col_in_chunk = pid % CS
    if chunk_linear >= num_active:
        return

    chunk_idx = tl.load(chunk_ids_ptr + chunk_linear)
    col_abs = chunk_idx * CS + col_in_chunk
    col_ok = col_abs < d_out

    # Per-lane partial sums; reduced to a scalar once after the loop.
    acc = tl.zeros((BM,), dtype=tl.float32)
    for m_start in range(0, M, BM):
        rm = m_start + tl.arange(0, BM)
        vals = tl.load(
            dY_ptr + rm * stride_dym + col_abs * stride_dyn,
            mask=(rm < M) & col_ok,
            other=0.0,
        )
        acc += vals.to(tl.float32)

    total = tl.sum(acc, axis=0)
    tl.store(dB_ptr + col_abs, total.to(dB_ptr.dtype.element_ty), mask=col_ok)
def sparse_bwd_dbias(dY, active_chunks, chunk_size, d_out):
    """Compute the bias gradient for the active output chunks with Triton.

    Args:
        dY: 2-D upstream gradient of shape (M, d_out).
        active_chunks: 1-D integer tensor of active chunk indices; chunk ``c``
            covers bias entries ``c * chunk_size`` up to
            ``(c + 1) * chunk_size``.
        chunk_size: number of bias entries per chunk.
        d_out: number of output features (length of the bias).

    Returns:
        A (d_out,) tensor on dY's device with dY's dtype. Entries belonging
        to active chunks hold the column sums of ``dY``; the rest are zero.

    Raises:
        ValueError: if ``dY`` is not 2-D or its column count is not ``d_out``.
    """
    # The kernel indexes raw pointers, so reject mismatched shapes up front
    # instead of reading or writing out of bounds on the device.
    if dY.dim() != 2 or dY.shape[1] != d_out:
        raise ValueError(
            f"dY must be 2-D with {d_out} columns, got shape {tuple(dY.shape)}"
        )

    M = dY.shape[0]
    num_active = active_chunks.shape[0]
    dB = torch.zeros(d_out, device=dY.device, dtype=dY.dtype)
    if num_active == 0:
        return dB

    # The kernel dereferences the ids on the compute device as int32.
    chunk_ids = active_chunks.to(device=dY.device, dtype=torch.int32).contiguous()

    BM = 128
    grid = (num_active * chunk_size,)  # one program per active column
    _sparse_bwd_dbias_kernel[grid](
        dY, dB, chunk_ids,
        M, d_out, num_active,
        dY.stride(0), dY.stride(1),
        CS=chunk_size, BM=BM,
    )
    return dB
# ═══════════════════════════════════════════════════════════════════
# AUTOGRAD FUNCTION: Triton-fused
# ═══════════════════════════════════════════════════════════════════
class TritonChunkedSparseLinear(torch.autograd.Function):
    """Linear layer with a dense forward pass and a chunk-sparse Triton backward pass.

    The forward pass is a plain ``F.linear``. In the backward pass only the
    weight rows / bias entries of the chunks listed in ``active_chunks``
    receive a gradient; all other rows get zero. The input gradient is either
    dense (``sparse_dx=False``) or restricted to the active chunks
    (``sparse_dx=True``).
    """

    @staticmethod
    def forward(ctx, x, weight, bias, active_chunks, chunk_size, sparse_dx):
        """Return ``F.linear(x, weight, bias)`` and stash what backward needs.

        Args:
            x: input whose last dimension is ``d_in``.
            weight: (d_out, d_in) weight matrix.
            bias: optional (d_out,) bias, or ``None``.
            active_chunks: 1-D integer tensor of active chunk indices.
            chunk_size: number of output features per chunk.
            sparse_dx: if true, the input gradient only sums over active chunks.
        """
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y):
        """Return gradients for ``(x, weight, bias)`` and ``None`` for the rest.

        Gradients that autograd did not ask for (per ``ctx.needs_input_grad``)
        are skipped and returned as ``None`` so no kernel is launched for them.
        """
        x, weight, active_chunks = ctx.saved_tensors
        cs = ctx.chunk_size
        d_out, d_in = weight.shape
        need_x, need_w, need_b = ctx.needs_input_grad[:3]

        x_flat = x.reshape(-1, d_in)
        gy_flat = grad_y.reshape(-1, d_out)
        M = x_flat.shape[0]

        # grad_W via Triton
        grad_w = None
        if need_w:
            grad_w = sparse_bwd_dW(x_flat, gy_flat, active_chunks, cs, d_out)

        # grad_bias via Triton
        grad_b = None
        if ctx.has_bias and need_b:
            grad_b = sparse_bwd_dbias(gy_flat, active_chunks, cs, d_out)

        # grad_X
        grad_x = None
        if need_x:
            if ctx.sparse_dx:
                grad_x_flat = sparse_bwd_dX(gy_flat, weight, active_chunks, cs, M, d_in)
            else:
                grad_x_flat = gy_flat @ weight  # dense dX
            grad_x = grad_x_flat.reshape(x.shape)

        return grad_x, grad_w, grad_b, None, None, None
# ═══════════════════════════════════════════════════════════════════
# AUTOGRAD FUNCTION: Python-loop baseline (for comparison)
# ═══════════════════════════════════════════════════════════════════
class PythonLoopSparseLinear(torch.autograd.Function):
    """Reference implementation of the chunk-sparse backward pass in plain PyTorch.

    The forward pass is a plain ``F.linear``. The backward pass loops over the
    active chunks in Python and fills in the matching weight rows / bias
    entries; all other rows stay zero. The input gradient is either dense
    (``sparse_dx=False``) or the sum over active chunks (``sparse_dx=True``).
    """

    @staticmethod
    def forward(ctx, x, weight, bias, active_chunks, chunk_size, sparse_dx):
        """Return ``F.linear(x, weight, bias)`` and stash what backward needs.

        Args:
            x: input whose last dimension is ``d_in``.
            weight: (d_out, d_in) weight matrix.
            bias: optional (d_out,) bias, or ``None``.
            active_chunks: 1-D integer tensor of active chunk indices.
            chunk_size: number of output features per chunk.
            sparse_dx: if true, the input gradient only sums over active chunks.
        """
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y):
        """Return gradients for ``(x, weight, bias)`` and ``None`` for the rest.

        Gradients that autograd did not ask for (per ``ctx.needs_input_grad``)
        are skipped and returned as ``None``.
        """
        x, weight, active_chunks = ctx.saved_tensors
        cs = ctx.chunk_size
        need_x, need_w, need_b = ctx.needs_input_grad[:3]
        need_b = need_b and ctx.has_bias

        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight) if need_w else None
        grad_b = None
        if need_b:
            grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)

        # Dense dX needs no per-chunk work; sparse dX is accumulated in the loop.
        accumulate_dx = need_x and ctx.sparse_dx
        if not need_x:
            grad_x_flat = None
        elif ctx.sparse_dx:
            grad_x_flat = torch.zeros_like(x_flat)
        else:
            grad_x_flat = gy_flat @ weight

        # Skip the loop (and the .tolist() device sync) when nothing needs it.
        if need_w or need_b or accumulate_dx:
            for c in active_chunks.tolist():
                s, e = c * cs, (c + 1) * cs
                gy_slice = gy_flat[:, s:e]
                if need_w:
                    grad_w[s:e, :] = gy_slice.t() @ x_flat
                if need_b:
                    grad_b[s:e] = gy_slice.sum(0)
                if accumulate_dx:
                    grad_x_flat += gy_slice @ weight[s:e, :]

        grad_x = grad_x_flat.reshape(x.shape) if need_x else None
        return grad_x, grad_w, grad_b, None, None, None
# ═══════════════════════════════════════════════════════════════════
# CORRECTNESS TEST
# ═══════════════════════════════════════════════════════════════════
def test_correctness():
    """Compare the Triton kernels against a plain PyTorch reference on CUDA.

    For a few (d_in, d_out, chunk_size) configurations, roughly 10% of the
    chunks are picked at random and the dW, dBias and sparse dX results are
    checked against a Python-loop reference using the maximum absolute
    error. One status line is printed per configuration.

    Returns:
        True if every configuration stayed below the error tolerance,
        False otherwise.
    """
    print("Testing correctness...")
    torch.manual_seed(42)
    device = "cuda"
    tol = 1e-2
    all_ok = True

    for d_in, d_out, cs in [(512, 2048, 64), (1024, 4096, 64), (256, 1024, 32)]:
        M = 2048  # B*T
        n_chunks = d_out // cs
        n_active = max(1, int(0.1 * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values

        x = torch.randn(M, d_in, device=device, requires_grad=False)
        w = torch.randn(d_out, d_in, device=device, requires_grad=False)
        b = torch.randn(d_out, device=device, requires_grad=False)
        gy = torch.randn(M, d_out, device=device, requires_grad=False)

        # Reference: Python loop over the active chunks
        ref_dw = torch.zeros_like(w)
        ref_db = torch.zeros_like(b)
        ref_dx_sparse = torch.zeros_like(x)
        for c in active.tolist():
            s, e = c * cs, (c + 1) * cs
            ref_dw[s:e] = gy[:, s:e].t() @ x
            ref_db[s:e] = gy[:, s:e].sum(0)
            ref_dx_sparse += gy[:, s:e] @ w[s:e]

        # Triton
        tri_dw = sparse_bwd_dW(x, gy, active, cs, d_out)
        tri_db = sparse_bwd_dbias(gy, active, cs, d_out)
        tri_dx_sparse = sparse_bwd_dX(gy, w, active, cs, M, d_in)

        # Compare
        dw_err = (tri_dw - ref_dw).abs().max().item()
        db_err = (tri_db - ref_db).abs().max().item()
        dx_err = (tri_dx_sparse - ref_dx_sparse).abs().max().item()

        ok = dw_err < tol and db_err < tol and dx_err < tol
        all_ok = all_ok and ok
        status = "✓" if ok else "✗"
        print(f" {status} d_in={d_in}, d_out={d_out}, cs={cs}: dW_err={dw_err:.6f}, dB_err={db_err:.6f}, dX_err={dx_err:.6f}")

    print()
    return all_ok
# ═══════════════════════════════════════════════════════════════════
# BENCHMARK
# ═══════════════════════════════════════════════════════════════════
def benchmark():
    """Time the dense, Python-loop and Triton backward passes on CUDA.

    Prints two tables: the first compares all three variants with a dense dX,
    the second compares dense against Triton with both dW and dX sparse.
    """
    print("="*80)
    print("BENCHMARK: Triton Fused vs Python Loop vs Dense")
    print("="*80)

    device = "cuda"
    B, T = 8, 256
    M = B * T
    cs = 64
    af = 0.10
    warmup_iters = 10
    bench_iters = 50

    def pick_active(n_chunks, n_active):
        # Random subset of chunk ids, sorted ascending.
        return torch.randperm(n_chunks, device=device)[:n_active].sort().values

    def seconds_per_call(fn):
        # Average wall-clock time of fn, bracketed by device synchronisation.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(bench_iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - t0) / bench_iters

    print(f"\nM={M} (B={B}, T={T}), chunk_size={cs}, active_frac={af}")
    print(f"{'d_model':>7} | {'d_out':>7} | {'active':>6} | {'Dense':>10} | {'PyLoop':>10} | {'Triton':>10} | {'Tri/Dense':>10} | {'Tri/PyLoop':>10}")
    print("-" * 95)

    for d_in in [256, 512, 768, 1024, 1536, 2048]:
        d_out = 4 * d_in
        n_chunks = d_out // cs
        n_active = max(1, int(af * n_chunks))
        active = pick_active(n_chunks, n_active)

        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        b = torch.randn(d_out, device=device)
        gy = torch.randn(M, d_out, device=device)

        # Dense backward (dW + dX + dB)
        def dense_bwd():
            return gy.t() @ x, gy @ w, gy.sum(0)

        # Python loop backward
        def pyloop_bwd():
            dw = torch.zeros_like(w)
            db = torch.zeros_like(b)
            dx = gy @ w  # dense dX
            for c in active.tolist():
                lo = c * cs
                hi = lo + cs
                dw[lo:hi] = gy[:, lo:hi].t() @ x
                db[lo:hi] = gy[:, lo:hi].sum(0)
            return dw, dx, db

        # Triton fused backward
        def triton_bwd():
            dw = sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = gy @ w  # dense dX (same as pyloop)
            db = sparse_bwd_dbias(gy, active, cs, d_out)
            return dw, dx, db

        # Warmup
        for _ in range(warmup_iters):
            dense_bwd()
            pyloop_bwd()
            triton_bwd()
        torch.cuda.synchronize()

        dense_time = seconds_per_call(dense_bwd)
        pyloop_time = seconds_per_call(pyloop_bwd)
        triton_time = seconds_per_call(triton_bwd)

        tri_vs_dense = dense_time / triton_time
        tri_vs_pyloop = pyloop_time / triton_time
        print(f"{d_in:>7} | {d_out:>7} | {n_active:>6} | {dense_time*1000:>9.2f}ms | {pyloop_time*1000:>9.2f}ms | {triton_time*1000:>9.2f}ms | {tri_vs_dense:>9.2f}x | {tri_vs_pyloop:>9.2f}x")

    # Also benchmark with sparse_dX (Triton dX kernel)
    print(f"\n{'='*80}")
    print("With Triton sparse_dX (both dW and dX are sparse):")
    print(f"{'d_model':>7} | {'Dense':>10} | {'Triton_all':>10} | {'Speedup':>10}")
    print("-" * 50)

    for d_in in [512, 1024, 2048]:
        d_out = 4 * d_in
        n_chunks = d_out // cs
        n_active = max(1, int(af * n_chunks))
        active = pick_active(n_chunks, n_active)

        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        def dense_full():
            return gy.t() @ x, gy @ w

        def triton_full():
            dw = sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = sparse_bwd_dX(gy, w, active, cs, M, d_in)
            return dw, dx

        for _ in range(warmup_iters):
            dense_full()
            triton_full()
        torch.cuda.synchronize()

        dt = seconds_per_call(dense_full)
        tt = seconds_per_call(triton_full)
        print(f"{d_in:>7} | {dt*1000:>9.2f}ms | {tt*1000:>9.2f}ms | {dt/tt:>9.2f}x")
# ═══════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Check the kernels against the reference first, then time them.
    for _step in (test_correctness, benchmark):
        _step()