Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
bc1b8eb | #!/usr/bin/env python3 | |
| """ | |
| Triton-fused Chunked Sparse Backward Pass. | |
| Replaces the Python for-loop over active chunks with fused Triton kernels: | |
| 1. sparse_bwd_dW: grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X for active c | |
| 2. sparse_bwd_dX: grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :] for active c | |
3. sparse_bwd_dbias: grad_b[c*CS:(c+1)*CS] = grad_Y[:, c*CS:(c+1)*CS].sum(0) for active c (the forward pass itself stays dense via F.linear)
| Benchmark against the Python-loop baseline at various d_model sizes. | |
| """ | |
| import math | |
| import os | |
| import random | |
| import time | |
| import urllib.request | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import triton | |
| import triton.language as tl | |
| try: | |
| import tiktoken | |
| except ImportError: | |
| raise ImportError("pip install tiktoken") | |
# ====================================================================
# TRITON KERNELS
# ====================================================================
# -- Kernel 1: Sparse dW ----------------------------------------------
# For each active chunk c:
#   grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X
#
# In terms of shapes:
#   grad_Y: (M, d_out), X: (M, d_in), W: (d_out, d_in)
#   For chunk c: rows c*CS..(c+1)*CS of W get grad from cols c*CS..(c+1)*CS of grad_Y
#
# Grid: (num_active * ceil(CS/BN), ceil(d_in/BK))
#   pid0 encodes (active_chunk_linear_id, N-block within CS)
#   pid1 encodes K-block within d_in
@triton.jit
def _sparse_bwd_dW_kernel(
    X_ptr, dY_ptr, dW_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_xm, stride_xk,
    stride_dym, stride_dyn,
    stride_dwn, stride_dwk,
    CS: tl.constexpr,
    BN: tl.constexpr,
    BK: tl.constexpr,
    BM: tl.constexpr,
):
    """Compute dW tiles for active chunks. Each program writes one [BN, BK] tile.

    Rows are masked both to the chunk (rn < CS) and to the output
    (chunk_start + rn < d_out), so a partial last chunk or an out-of-range
    chunk id never reads past the end of dY or writes past the end of dW.
    """
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    N_BLOCKS_PER_CHUNK = tl.cdiv(CS, BN)
    chunk_linear_id = pid0 // N_BLOCKS_PER_CHUNK
    n_block_id = pid0 % N_BLOCKS_PER_CHUNK
    k_block_id = pid1
    if chunk_linear_id >= num_active:
        return
    chunk_idx = tl.load(chunk_ids_ptr + chunk_linear_id)
    chunk_start = chunk_idx * CS

    # Tile ranges
    rn = n_block_id * BN + tl.arange(0, BN)  # rows of dW relative to the chunk start
    rk = k_block_id * BK + tl.arange(0, BK)  # cols of dW (= cols of X)
    n_abs = chunk_start + rn  # absolute row of dW == absolute column of dY
    n_mask = (rn < CS) & (n_abs < d_out)
    k_mask = rk < d_in

    # Accumulate dY[:, chunk_cols].T @ X[:, k_cols] over M-tiles in fp32.
    acc = tl.zeros((BN, BK), dtype=tl.float32)
    for m_start in range(0, M, BM):
        rm = m_start + tl.arange(0, BM)
        m_mask = rm < M
        # X tile: (BM, BK)
        x = tl.load(
            X_ptr + rm[:, None] * stride_xm + rk[None, :] * stride_xk,
            mask=m_mask[:, None] & k_mask[None, :],
            other=0.0,
        )
        # dY tile: (BM, BN)
        dy = tl.load(
            dY_ptr + rm[:, None] * stride_dym + n_abs[None, :] * stride_dyn,
            mask=m_mask[:, None] & n_mask[None, :],
            other=0.0,
        )
        # dY.T @ X -> (BN, BK)
        acc = tl.dot(tl.trans(dy), x, acc=acc)

    # Write to dW (d_out, d_in): row = chunk_start + rn, col = rk
    dw_ptrs = dW_ptr + n_abs[:, None] * stride_dwn + rk[None, :] * stride_dwk
    tl.store(dw_ptrs, acc.to(dW_ptr.dtype.element_ty), mask=n_mask[:, None] & k_mask[None, :])


def sparse_bwd_dW(X, dY, active_chunks, chunk_size, d_out,
                  block_n=None, block_k=64, block_m=32):
    """Compute grad_W of a chunk-sparse linear layer with one Triton launch.

    Args:
        X: input activations, shape (M, d_in).
        dY: upstream gradient, shape (M, d_out).
        active_chunks: 1-D integer tensor of chunk indices; chunk c covers
            rows [c*chunk_size, (c+1)*chunk_size) of W. Copied to X.device
            as int32 before launch.
        chunk_size: rows of W per chunk.
        d_out: number of rows of W (rows at or beyond it are never touched).
        block_n, block_k, block_m: Triton tile sizes. They must be powers of
            two and at least 16 because they are tl.dot operand dimensions.
            block_n defaults to chunk_size rounded up to a power of two,
            clamped to [16, 64].

    Returns:
        dW of shape (d_out, d_in) and dtype X.dtype; rows belonging to
        inactive chunks are zero.
    """
    M, d_in = X.shape
    num_active = active_chunks.shape[0]
    CS = chunk_size
    dW = torch.zeros(d_out, d_in, device=X.device, dtype=X.dtype)
    if num_active == 0:
        return dW
    BN = block_n or max(16, min(64, triton.next_power_of_2(CS)))
    # The kernel dereferences the chunk ids on the device, so they must live there.
    chunk_ids = active_chunks.to(device=X.device, dtype=torch.int32).contiguous()
    # Tile sizes are passed explicitly: no autotuner is attached to the kernel.
    grid = (num_active * triton.cdiv(CS, BN), triton.cdiv(d_in, block_k))
    _sparse_bwd_dW_kernel[grid](
        X, dY, dW, chunk_ids,
        M, d_in, d_out, num_active,
        X.stride(0), X.stride(1),
        dY.stride(0), dY.stride(1),
        dW.stride(0), dW.stride(1),
        CS=CS, BN=BN, BK=block_k, BM=block_m,
    )
    return dW
# -- Kernel 2: Sparse dX ----------------------------------------------
# For each active chunk c:
#   grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :]
#
# Grid: (ceil(M/BM), ceil(d_in/BK))
#   Each program accumulates contributions from ALL active chunks.
@triton.jit
def _sparse_bwd_dX_kernel(
    dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_dym, stride_dyn,
    stride_wn, stride_wk,
    stride_dxm, stride_dxk,
    CS: tl.constexpr,
    BM: tl.constexpr,
    BK: tl.constexpr,
    BN: tl.constexpr,
):
    """Compute one [BM, BK] tile of dX by summing over all active chunks.

    Chunk rows are masked to both the chunk (rn < CS) and the output width
    (chunk_start + rn < d_out), so a partial last chunk stays in bounds.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    rm = pid_m * BM + tl.arange(0, BM)
    rk = pid_k * BK + tl.arange(0, BK)
    m_mask = rm < M
    k_mask = rk < d_in
    acc = tl.zeros((BM, BK), dtype=tl.float32)

    # Sum over all active chunks
    for i in range(num_active):
        chunk_idx = tl.load(chunk_ids_ptr + i)
        chunk_start = chunk_idx * CS
        # Tile over BN within the chunk
        for n_start in range(0, CS, BN):
            rn = n_start + tl.arange(0, BN)
            n_abs = chunk_start + rn
            n_mask = (rn < CS) & (n_abs < d_out)
            # dY tile: (BM, BN)
            dy = tl.load(
                dY_ptr + rm[:, None] * stride_dym + n_abs[None, :] * stride_dyn,
                mask=m_mask[:, None] & n_mask[None, :],
                other=0.0,
            )
            # W tile: (BN, BK) = W[chunk_start + rn, rk]
            w = tl.load(
                W_ptr + n_abs[:, None] * stride_wn + rk[None, :] * stride_wk,
                mask=n_mask[:, None] & k_mask[None, :],
                other=0.0,
            )
            # dY @ W -> (BM, BK)
            acc = tl.dot(dy, w, acc=acc)

    dx_ptrs = dX_ptr + rm[:, None] * stride_dxm + rk[None, :] * stride_dxk
    tl.store(dx_ptrs, acc.to(dX_ptr.dtype.element_ty), mask=m_mask[:, None] & k_mask[None, :])


def sparse_bwd_dX(dY, W, active_chunks, chunk_size, M, d_in,
                  block_m=64, block_k=64, block_n=None):
    """Compute grad_X restricted to the active output chunks with one Triton launch.

    Args:
        dY: upstream gradient, shape (M, d_out); d_out is taken from dY.shape[1].
        W: weight, shape (d_out, d_in).
        active_chunks: 1-D integer tensor of chunk indices; copied to
            dY.device as int32 before launch.
        chunk_size: rows of W per chunk.
        M, d_in: shape of the returned gradient.
        block_m, block_k, block_n: Triton tile sizes. They must be powers of
            two and at least 16 because they are tl.dot operand dimensions.
            block_n defaults to chunk_size rounded up to a power of two,
            clamped to [16, 64].

    Returns:
        dX of shape (M, d_in) and dtype dY.dtype, equal to the sum over active
        chunks of dY[:, chunk] @ W[chunk, :].
    """
    num_active = active_chunks.shape[0]
    CS = chunk_size
    dX = torch.zeros(M, d_in, device=dY.device, dtype=dY.dtype)
    if num_active == 0:
        return dX
    BN = block_n or max(16, min(64, triton.next_power_of_2(CS)))
    # The kernel dereferences the chunk ids on the device, so they must live there.
    chunk_ids = active_chunks.to(device=dY.device, dtype=torch.int32).contiguous()
    # Tile sizes are passed explicitly: no autotuner is attached to the kernel.
    grid = (triton.cdiv(M, block_m), triton.cdiv(d_in, block_k))
    _sparse_bwd_dX_kernel[grid](
        dY, W, dX, chunk_ids,
        M, d_in, dY.shape[1], num_active,
        dY.stride(0), dY.stride(1),
        W.stride(0), W.stride(1),
        dX.stride(0), dX.stride(1),
        CS=CS, BM=block_m, BK=block_k, BN=BN,
    )
    return dX
# -- Kernel 3: Sparse dBias -------------------------------------------
# bias_grad[c*CS:(c+1)*CS] = dY[:, c*CS:(c+1)*CS].sum(dim=0) for active c
#
# Grid: (num_active * ceil(CS/BN),)
#   Each program reduces a [M, BN] column strip of dY in [BM, BN] tiles,
#   so every row segment it loads spans BN adjacent columns instead of a
#   single strided column per program.
@triton.jit
def _sparse_bwd_dbias_kernel(
    dY_ptr, dB_ptr, chunk_ids_ptr,
    M, d_out, num_active,
    stride_dym, stride_dyn,
    CS: tl.constexpr,
    BM: tl.constexpr,
    BN: tl.constexpr,
):
    """Write BN entries of the bias gradient for one active chunk (fp32 accumulation)."""
    pid = tl.program_id(0)
    n_blocks_per_chunk = tl.cdiv(CS, BN)
    chunk_linear = pid // n_blocks_per_chunk
    n_block = pid % n_blocks_per_chunk
    if chunk_linear >= num_active:
        return
    chunk_idx = tl.load(chunk_ids_ptr + chunk_linear)
    rn = n_block * BN + tl.arange(0, BN)
    n_abs = chunk_idx * CS + rn
    # Mask to the chunk and to d_out so a partial last chunk stays in bounds.
    n_mask = (rn < CS) & (n_abs < d_out)

    acc = tl.zeros((BN,), dtype=tl.float32)
    for m_start in range(0, M, BM):
        rm = m_start + tl.arange(0, BM)
        m_mask = rm < M
        vals = tl.load(
            dY_ptr + rm[:, None] * stride_dym + n_abs[None, :] * stride_dyn,
            mask=m_mask[:, None] & n_mask[None, :],
            other=0.0,
        )
        acc += tl.sum(vals.to(tl.float32), axis=0)
    # dB is allocated by the wrapper as a contiguous 1-D tensor.
    tl.store(dB_ptr + n_abs, acc.to(dB_ptr.dtype.element_ty), mask=n_mask)


def sparse_bwd_dbias(dY, active_chunks, chunk_size, d_out, block_m=64, block_n=None):
    """Compute the bias gradient restricted to the active output chunks.

    Args:
        dY: upstream gradient, shape (M, d_out).
        active_chunks: 1-D integer tensor of chunk indices; copied to
            dY.device as int32 before launch.
        chunk_size: output features per chunk.
        d_out: length of the returned gradient.
        block_m, block_n: Triton tile sizes (powers of two). block_n defaults
            to chunk_size rounded up to a power of two, capped at 64.

    Returns:
        dB of shape (d_out,) and dtype dY.dtype; entries of inactive chunks are zero.
    """
    M = dY.shape[0]
    num_active = active_chunks.shape[0]
    dB = torch.zeros(d_out, device=dY.device, dtype=dY.dtype)
    if num_active == 0:
        return dB
    BN = block_n or min(64, triton.next_power_of_2(chunk_size))
    # The kernel dereferences the chunk ids on the device, so they must live there.
    chunk_ids = active_chunks.to(device=dY.device, dtype=torch.int32).contiguous()
    grid = (num_active * triton.cdiv(chunk_size, BN),)
    _sparse_bwd_dbias_kernel[grid](
        dY, dB, chunk_ids,
        M, d_out, num_active,
        dY.stride(0), dY.stride(1),
        CS=chunk_size, BM=block_m, BN=BN,
    )
    return dB
# ====================================================================
# AUTOGRAD FUNCTION: Triton-fused
# ====================================================================
class TritonChunkedSparseLinear(torch.autograd.Function):
    """Dense linear forward with a chunk-sparse backward built on Triton kernels.

    Forward is ``F.linear(x, weight, bias)``. Backward produces grad_W and
    grad_b only for the rows of ``weight`` that belong to ``active_chunks``
    (chunk c covers rows [c*chunk_size, (c+1)*chunk_size)); all other rows
    are zero. grad_x is the dense ``grad_y @ weight`` unless ``sparse_dx`` is
    true, in which case only the active chunks contribute. Gradients are
    computed only for inputs that require them.

    Call as ``TritonChunkedSparseLinear.apply(x, weight, bias, active_chunks,
    chunk_size, sparse_dx)``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, active_chunks, chunk_size, sparse_dx):
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y):
        x, weight, active_chunks = ctx.saved_tensors
        need_x, need_w, need_b = ctx.needs_input_grad[:3]
        cs = ctx.chunk_size
        d_out, d_in = weight.shape
        # Flatten all leading dims so the kernels see 2-D (M, features) views.
        x_flat = x.reshape(-1, d_in)
        gy_flat = grad_y.reshape(-1, d_out)
        M = x_flat.shape[0]

        grad_w = sparse_bwd_dW(x_flat, gy_flat, active_chunks, cs, d_out) if need_w else None
        grad_b = (
            sparse_bwd_dbias(gy_flat, active_chunks, cs, d_out)
            if ctx.has_bias and need_b
            else None
        )

        grad_x = None
        if need_x:
            if ctx.sparse_dx:
                grad_x_flat = sparse_bwd_dX(gy_flat, weight, active_chunks, cs, M, d_in)
            else:
                grad_x_flat = gy_flat @ weight  # dense dX
            grad_x = grad_x_flat.reshape(x.shape)

        # No gradients for active_chunks, chunk_size, sparse_dx.
        return grad_x, grad_w, grad_b, None, None, None
# ====================================================================
# AUTOGRAD FUNCTION: Python-loop baseline (for comparison)
# ====================================================================
class PythonLoopSparseLinear(torch.autograd.Function):
    """Reference chunk-sparse linear layer whose backward loops over chunks in Python.

    Same contract as the Triton-fused variant: dense ``F.linear`` forward;
    backward fills grad_W / grad_b only for rows in ``active_chunks``
    (chunk c covers rows [c*chunk_size, (c+1)*chunk_size)) and computes
    grad_x densely, or from the active chunks only when ``sparse_dx`` is
    true. Gradients are computed only for inputs that require them.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, active_chunks, chunk_size, sparse_dx):
        ctx.save_for_backward(x, weight, active_chunks)
        ctx.has_bias = bias is not None
        ctx.sparse_dx = sparse_dx
        ctx.chunk_size = chunk_size
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y):
        x, weight, active_chunks = ctx.saved_tensors
        need_x, need_w, need_b = ctx.needs_input_grad[:3]
        need_b = need_b and ctx.has_bias
        cs = ctx.chunk_size
        x_flat = x.reshape(-1, x.shape[-1])
        gy_flat = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight) if need_w else None
        grad_b = (
            torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)
            if need_b
            else None
        )
        grad_x_flat = None
        if need_x:
            grad_x_flat = torch.zeros_like(x_flat) if ctx.sparse_dx else gy_flat @ weight
        accumulate_dx = grad_x_flat is not None and ctx.sparse_dx

        # .tolist() syncs with the device once; each chunk then launches its own ops.
        for c in active_chunks.tolist():
            s, e = c * cs, (c + 1) * cs
            gy_slice = gy_flat[:, s:e]
            if grad_w is not None:
                grad_w[s:e, :] = gy_slice.t() @ x_flat
            if grad_b is not None:
                grad_b[s:e] = gy_slice.sum(0)
            if accumulate_dx:
                grad_x_flat += gy_slice @ weight[s:e, :]

        grad_x = grad_x_flat.reshape(x.shape) if grad_x_flat is not None else None
        # No gradients for active_chunks, chunk_size, sparse_dx.
        return grad_x, grad_w, grad_b, None, None, None
# ====================================================================
# CORRECTNESS TEST
# ====================================================================
def test_correctness(rtol=1e-2):
    """Compare the Triton sparse-backward kernels with plain PyTorch slicing.

    For several (d_in, d_out, chunk_size) shapes it picks ~10% of the chunks
    at random, computes dW, dB and the chunk-sparse dX both ways on CUDA and
    prints the max absolute error of each.

    Pass/fail is judged on error relative to the largest reference magnitude.
    The reductions run over thousands of terms, so reference values reach
    the hundreds. ``tl.dot`` may also use TF32 for float32 inputs while the
    PyTorch reference may not. A fixed absolute bound would therefore flag
    correct kernels.

    Args:
        rtol: largest allowed ``max|triton - ref| / max|ref|`` per gradient.

    Returns:
        True if every case passed.
    """
    print("Testing correctness...")
    torch.manual_seed(42)
    device = "cuda"

    def rel_err(got, ref):
        # Max-abs error scaled by the largest reference magnitude.
        scale = ref.abs().max().clamp_min(1e-12)
        return ((got - ref).abs().max() / scale).item()

    all_ok = True
    for d_in, d_out, cs in [(512, 2048, 64), (1024, 4096, 64), (256, 1024, 32)]:
        M = 2048  # B*T
        n_chunks = d_out // cs
        n_active = max(1, int(0.1 * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        # Reference: per-chunk slicing in PyTorch.
        ref_dw = torch.zeros_like(w)
        ref_db = torch.zeros(d_out, device=device)
        ref_dx = torch.zeros_like(x)
        for c in active.tolist():
            s, e = c * cs, (c + 1) * cs
            ref_dw[s:e] = gy[:, s:e].t() @ x
            ref_db[s:e] = gy[:, s:e].sum(0)
            ref_dx += gy[:, s:e] @ w[s:e]

        # Triton
        tri_dw = sparse_bwd_dW(x, gy, active, cs, d_out)
        tri_db = sparse_bwd_dbias(gy, active, cs, d_out)
        tri_dx = sparse_bwd_dX(gy, w, active, cs, M, d_in)

        dw_err = (tri_dw - ref_dw).abs().max().item()
        db_err = (tri_db - ref_db).abs().max().item()
        dx_err = (tri_dx - ref_dx).abs().max().item()
        worst_rel = max(
            rel_err(tri_dw, ref_dw),
            rel_err(tri_db, ref_db),
            rel_err(tri_dx, ref_dx),
        )
        ok = worst_rel < rtol
        all_ok = all_ok and ok
        status = "PASS" if ok else "FAIL"
        print(f"  {status} d_in={d_in}, d_out={d_out}, cs={cs}: dW_err={dw_err:.6f}, dB_err={db_err:.6f}, dX_err={dx_err:.6f}, max_rel_err={worst_rel:.2e}")
    print()
    return all_ok
# ====================================================================
# BENCHMARK
# ====================================================================
def _time_cuda(fn, iters):
    """Return mean wall-clock seconds per call of ``fn`` over ``iters`` calls.

    The device is synchronized before starting the clock and after the last
    call, so queued CUDA work is included in the measurement.
    """
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters


def benchmark(B=8, T=256, cs=64, af=0.10, warmup_iters=10, bench_iters=50):
    """Time dense vs Python-loop vs Triton sparse backward passes and print tables.

    Table 1 compares dW + dX + dB, with dX computed densely in all three
    variants. Table 2 compares dense dW + dX against the Triton kernels with
    both dW and dX restricted to the active chunks.

    Args:
        B, T: batch size and sequence length; the flattened row count is M = B * T.
        cs: chunk size (rows of W per chunk).
        af: fraction of chunks that are active (at least one is always active).
        warmup_iters: untimed calls of every variant before measuring.
        bench_iters: timed calls per variant; the mean time is reported.
    """
    print("="*80)
    print("BENCHMARK: Triton Fused vs Python Loop vs Dense")
    print("="*80)
    device = "cuda"
    M = B * T
    print(f"\nM={M} (B={B}, T={T}), chunk_size={cs}, active_frac={af}")
    print(f"{'d_model':>7} | {'d_out':>7} | {'active':>6} | {'Dense':>10} | {'PyLoop':>10} | {'Triton':>10} | {'Tri/Dense':>10} | {'Tri/PyLoop':>10}")
    print("-" * 95)
    for d_in in [256, 512, 768, 1024, 1536, 2048]:
        d_out = 4 * d_in
        n_chunks = d_out // cs
        n_active = max(1, int(af * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        b = torch.randn(d_out, device=device)
        gy = torch.randn(M, d_out, device=device)

        # The closures below are only called inside this iteration, so
        # capturing the loop variables is safe.
        def dense_bwd():
            # Dense backward (dW + dX + dB)
            return gy.t() @ x, gy @ w, gy.sum(0)

        def pyloop_bwd():
            # Python loop over active chunks; dX stays dense.
            dw = torch.zeros_like(w)
            db = torch.zeros_like(b)
            dx = gy @ w
            for c in active.tolist():
                s, e = c * cs, (c + 1) * cs
                dw[s:e] = gy[:, s:e].t() @ x
                db[s:e] = gy[:, s:e].sum(0)
            return dw, dx, db

        def triton_bwd():
            # Triton fused dW / dB; dX dense (same as pyloop).
            dw = sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = gy @ w
            db = sparse_bwd_dbias(gy, active, cs, d_out)
            return dw, dx, db

        for _ in range(warmup_iters):
            dense_bwd()
            pyloop_bwd()
            triton_bwd()
        dense_time = _time_cuda(dense_bwd, bench_iters)
        pyloop_time = _time_cuda(pyloop_bwd, bench_iters)
        triton_time = _time_cuda(triton_bwd, bench_iters)

        tri_vs_dense = dense_time / triton_time
        tri_vs_pyloop = pyloop_time / triton_time
        print(f"{d_in:>7} | {d_out:>7} | {n_active:>6} | {dense_time*1000:>9.2f}ms | {pyloop_time*1000:>9.2f}ms | {triton_time*1000:>9.2f}ms | {tri_vs_dense:>9.2f}x | {tri_vs_pyloop:>9.2f}x")

    # Also benchmark with sparse_dX (Triton dX kernel)
    print(f"\n{'='*80}")
    print("With Triton sparse_dX (both dW and dX are sparse):")
    print(f"{'d_model':>7} | {'Dense':>10} | {'Triton_all':>10} | {'Speedup':>10}")
    print("-" * 50)
    for d_in in [512, 1024, 2048]:
        d_out = 4 * d_in
        n_chunks = d_out // cs
        n_active = max(1, int(af * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values
        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        def dense_full():
            return gy.t() @ x, gy @ w

        def triton_full():
            dw = sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = sparse_bwd_dX(gy, w, active, cs, M, d_in)
            return dw, dx

        for _ in range(warmup_iters):
            dense_full()
            triton_full()
        dt = _time_cuda(dense_full, bench_iters)
        tt = _time_cuda(triton_full, bench_iters)
        print(f"{d_in:>7} | {dt*1000:>9.2f}ms | {tt*1000:>9.2f}ms | {dt/tt:>9.2f}x")
# ====================================================================
# MAIN
# ====================================================================
if __name__ == "__main__":
    # Check kernel correctness first, then run the timing tables.
    for _stage in (test_correctness, benchmark):
        _stage()