Mamba WebGPU -- First Browser-Native SSM Inference Engine
Falcon-Mamba 7B running in a browser tab. Pure WebGPU compute shaders. No MLC, no TVM, no WASM, no compilation step. 12 hand-written WGSL shaders. First ever browser-native Mamba/SSM inference.
What This Is
A complete inference runtime for Falcon-Mamba-7B-Instruct that runs entirely in the browser using WebGPU compute shaders. No server-side inference -- the model loads into GPU memory via the browser's WebGPU API and generates text using hand-written WGSL compute shaders.
This is NOT a transformer runtime. This is an SSM (State Space Model) runtime -- the Mamba architecture, which uses persistent recurrent state instead of KV cache. The state is fixed-size (38MB) regardless of context length.
Why This Matters
- WebLLM ships transformer models to the browser. This ships SSM models.
- MLC/TVM don't support the Mamba architecture (confirmed)
- The SSM state IS persistent memory -- save it, restore it, the entity remembers
- Fixed 38MB state vs unbounded KV cache growth
- No server needed for inference
Quick Start
```bash
# Clone this repo
git clone https://huggingface.co/LJTSG/mamba-webgpu
cd mamba-webgpu
# Start the dev server (serves weights from HF cache via byte-range requests)
node serve_mamba.js
# Open http://localhost:8140
# Click: Initialize -> Load Weights -> Generate
```
Requirements:
- Falcon-Mamba-7B-Instruct weights in your HuggingFace cache (`~/.cache/huggingface/hub/models--tiiuae--falcon-mamba-7b-instruct/`)
- Node.js (for the dev server)
- Python + `transformers` (for tokenization)
- Chrome/Edge with WebGPU support
- GPU with >= 16GB accessible via WebGPU (tested on AMD Strix Halo iGPU with 64GB unified memory)
Architecture
```
Token -> Embedding lookup
  -> 64x Mamba Layer:
       RMSNorm -> in_proj GEMV -> split(x, gate)
       -> conv1d_step (persistent state)
       -> SiLU
       -> x_proj GEMV -> RMSNorm(dt_pre, B, C)   [Falcon-Mamba specific]
       -> dt_proj GEMV -> softplus
       -> SSU (selective state update, persistent state)
       -> SiLU(gate) * hidden_y
       -> out_proj GEMV -> residual add
  -> Final RMSNorm -> lm_head GEMV -> Temperature sampling
```
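To make the data flow above concrete, here is a minimal plain-Python reference for one layer's recurrent part (conv1d_step through the gated output), the kind of CPU version one might check shader outputs against. It is a sketch under stated assumptions, not the runtime's code: toy dimensions, a width-4 conv kernel with bias, a standard Mamba skip term `D`, an RMSNorm epsilon of 1e-6, and stand-in values in place of the `in_proj`, `x_proj` and `dt_proj` GEMVs. All helper names are hypothetical.

```python
import math

# Hypothetical CPU reference for one Mamba decode step; toy sizes, not the runtime's code.

def rmsnorm_noweight(v, eps=1e-6):
    """Weightless RMSNorm (applied to dt_pre, B and C in Falcon-Mamba). eps is assumed."""
    rms = math.sqrt(sum(t * t for t in v) / len(v) + eps)
    return [t / rms for t in v]

def silu(t):
    return t / (1.0 + math.exp(-t))

def softplus(t):
    # log(1 + e^t), rearranged so large |t| does not overflow
    return t + math.log1p(math.exp(-t)) if t > 0 else math.log1p(math.exp(t))

def conv1d_step(conv_state, x, weight, bias):
    """Depthwise causal conv for one token; conv_state[c] keeps the last k-1 inputs."""
    out = []
    for c, xc in enumerate(x):
        window = conv_state[c] + [xc]                 # k most recent inputs, oldest first
        out.append(sum(w * v for w, v in zip(weight[c], window)) + bias[c])
        conv_state[c] = window[1:]                    # persistent state: drop the oldest
    return out

def ssu_step(h, x, dt, A_log, B, C, D):
    """Selective state update; h[c][n] is the persistent SSM state, updated in place."""
    y = []
    for c, xc in enumerate(x):
        acc = 0.0
        for n in range(len(B)):
            A = -math.exp(A_log[c][n])                # checkpoint stores A_log; decay needs A < 0
            h[c][n] = math.exp(dt[c] * A) * h[c][n] + dt[c] * B[n] * xc
            acc += h[c][n] * C[n]
        y.append(acc + D[c] * xc)                     # D: assumed standard Mamba skip term
    return y

# Toy sizes (hypothetical): 2 channels, d_state = 2, kernel width 4, dt_rank = 2.
d_inner, d_state, k = 2, 2, 4
conv_state = [[0.0] * (k - 1) for _ in range(d_inner)]
h = [[0.0] * d_state for _ in range(d_inner)]

x_in, gate = [0.5, -1.0], [1.0, 2.0]                  # the two halves of in_proj's output
x = [silu(v) for v in conv1d_step(conv_state, x_in, [[0.1, 0.2, 0.3, 0.4]] * d_inner, [0.0] * d_inner)]

# Stand-ins for the x_proj output, split into dt_pre, B, C and normalized without weights.
dt_pre, B, C = rmsnorm_noweight([0.3, -0.6]), rmsnorm_noweight([1.0, 2.0]), rmsnorm_noweight([0.5, -0.5])
W_dt = [[0.2, 0.1], [-0.1, 0.3]]                      # stand-in dt_proj weights (bias omitted)
dt = [softplus(sum(w * p for w, p in zip(row, dt_pre))) for row in W_dt]

y = ssu_step(h, x, dt, [[0.0, 0.5]] * d_inner, B, C, [1.0] * d_inner)
out = [silu(g) * yc for g, yc in zip(gate, y)]        # next: out_proj GEMV + residual add
```

`conv_state` and `h` keep the same shape however many tokens pass through; only their contents change, which is what makes the state a fixed-size object that can be saved and restored.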
12 WGSL Compute Shaders:
| Shader | Purpose | Workgroup |
|---|---|---|
| `rmsnorm.wgsl` | Root mean square normalization (with weights) | 64 threads |
| `rmsnorm_noweight.wgsl` | RMSNorm without learned weights (B/C/dt normalization) | 64 threads |
| `matmul_gemv.wgsl` | Matrix-vector product (M=1 specialized) | 64 threads/row |
| `conv1d_step.wgsl` | Autoregressive depthwise conv1d with state cache | 64 threads |
| `ssu.wgsl` | Selective state update (SSM scan, the core Mamba op) | 16 threads |
| `silu.wgsl` | SiLU/Swish activation, in-place | 64 threads |
| `softplus.wgsl` | Softplus activation | 64 threads |
| `embedding.wgsl` | Embedding table lookup | - |
| `elementwise_mul.wgsl` | Element-wise multiply, in-place | 64 threads |
| `add_residual.wgsl` | Residual connection add, in-place | 64 threads |
| `sample.wgsl` | Temperature-based multinomial sampling | 256 threads |
| `bf16_to_f32.wgsl` | BFloat16 to Float32 conversion | 64 threads |
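BF16 keeps the sign, the full 8-bit exponent and the top 7 mantissa bits of an IEEE-754 float32, so widening it to F32 is exact: shift the 16-bit pattern left by 16. The Python sketch below shows that bit-level conversion on little-endian bytes as stored in safetensors; it illustrates the idea behind `bf16_to_f32.wgsl` and the CPU conversion path, not their actual code.

```python
import struct

def bf16_bytes_to_f32(raw: bytes) -> list[float]:
    """Widen little-endian BF16 values (safetensors layout) to Python floats.

    A BF16 value is the upper half of a float32 bit pattern, so the
    conversion is a 16-bit left shift with no rounding involved.
    """
    if len(raw) % 2:
        raise ValueError("BF16 data must have an even number of bytes")
    count = len(raw) // 2
    halves = struct.unpack(f"<{count}H", raw)
    widened = struct.pack(f"<{count}I", *(h << 16 for h in halves))
    return list(struct.unpack(f"<{count}f", widened))

def f32_to_bf16_truncate(x: float) -> int:
    """Inverse used for testing: keep the top 16 bits (truncation, not round-to-nearest)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

raw = struct.pack("<3H", 0x3F80, 0xC000, 0x3EC0)   # BF16 bit patterns for 1.0, -2.0, 0.375
values = bf16_bytes_to_f32(raw)                     # [1.0, -2.0, 0.375]
```

This is for illustration only; at checkpoint scale the browser equivalent would apply the same shift over typed arrays (read the bytes as a `Uint16Array`, shift each value into a `Uint32Array`, then view that buffer as a `Float32Array`).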
Performance
Tested on AMD Strix Halo (Radeon 8060S iGPU, RDNA 3.5, 64GB unified memory):
- ~3 tok/s (~180ms per token)
- ~60s weight loading (~14GB BF16 checkpoint fetched via byte-range requests, widened to F32)
- 38MB persistent SSM state (64 layers x 608KB)
- ~960 shader dispatches per token (15 ops x 64 layers)
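The 38MB state figure and the dispatch count can be reproduced with a little arithmetic. The dimensions below are assumptions taken from the public Falcon-Mamba-7B configuration rather than from this runtime's code (8192 channels matches the 8192-element buffers mentioned in the build story). The sizes come out in binary units, so "608KB" and "38MB" correspond to KiB and MiB.

```python
# Assumed Falcon-Mamba-7B dimensions (public model config, not read from this runtime).
N_LAYERS = 64
D_INNER = 8192      # expanded channel count (2 x d_model = 2 x 4096)
D_STATE = 16        # SSM state values per channel
D_CONV = 4          # conv1d kernel width; a step kernel only needs to cache D_CONV - 1 inputs
BYTES_F32 = 4
OPS_PER_LAYER = 15  # dispatch count per layer, as stated above

ssm_bytes = D_INNER * D_STATE * BYTES_F32          # h[c][n] for every channel
conv_bytes = D_INNER * (D_CONV - 1) * BYTES_F32    # rolling conv window per channel
per_layer_kib = (ssm_bytes + conv_bytes) // 1024
total_mib = N_LAYERS * (ssm_bytes + conv_bytes) / 2**20

print(f"SSM state  : {ssm_bytes // 1024} KiB/layer")            # 512 KiB/layer
print(f"conv state : {conv_bytes // 1024} KiB/layer")           # 96 KiB/layer
print(f"per layer  : {per_layer_kib} KiB, total {total_mib:.1f} MiB")  # 608 KiB, total 38.0 MiB
print(f"dispatches : {OPS_PER_LAYER * N_LAYERS} per token for the layer stack")  # 960
```

None of these terms depends on the number of tokens processed, which is why the state stays the same size at any context length.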
The Build Story
Built over 36 hours across two sessions. Six bugs stood between "all zeros" and coherent output:
- Buffer alignment -- WebGPU requires storage buffer binding offsets to be multiples of `minStorageBufferOffsetAlignment` (256 bytes by default). An unaligned offset silently invalidated entire command encoders.
- A_log transform -- Falcon-Mamba stores A_log; the SSU needs A = -exp(A_log) for proper state decay.
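- Storage buffer limit -- The SSU shader uses 9 storage buffers, but the default `maxStorageBuffersPerShaderStage` limit is 8, so the device has to be requested with a higher limit (via `requiredLimits`) on adapters that support it.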
- Illegal buffer flags -- MAP_READ can only be combined with COPY_DST, not with STORAGE, so GPU readbacks need a separate staging buffer.
- Diagnostic overhead -- Per-token GPU readbacks for debugging were causing device timeouts.
- Missing RMSNorm on B, C, dt_pre -- Falcon-Mamba applies weightless RMSNorm to B, C, and dt_pre before the SSU. Standard Mamba does not. This was the final bug -- every shader was correct, but we were implementing the wrong model.
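One way to stay clear of the alignment bug is to round every sub-buffer offset up to the alignment before binding it. The sketch below shows only that offset arithmetic, using 256 as the default `minStorageBufferOffsetAlignment`; the buffer names and sizes are hypothetical, and it is an assumption that the runtime packs buffers this way. A robust version would read the alignment from the device's limits instead of hard-coding it.

```python
def align_up(n: int, alignment: int = 256) -> int:
    """Round n up to a multiple of alignment (alignment must be a power of two)."""
    return (n + alignment - 1) & ~(alignment - 1)

def pack_offsets(sizes: dict[str, int], alignment: int = 256) -> tuple[dict[str, int], int]:
    """Place named sub-buffers back to back, each starting at an aligned offset.

    Returns the byte offset of each sub-buffer and the total size needed.
    """
    offsets: dict[str, int] = {}
    cursor = 0
    for name, size in sizes.items():
        cursor = align_up(cursor, alignment)   # binding offsets must be aligned
        offsets[name] = cursor
        cursor += size
    return offsets, cursor

# Hypothetical scratch buffers (sizes in bytes) sharing one GPU buffer.
offsets, total = pack_offsets({"dt_pre": 1024, "B": 64, "C": 64, "y": 32768})
# offsets == {"dt_pre": 0, "B": 1024, "C": 1280, "y": 1536}, total == 34304
```

Placed back to back without padding, `C` would start at byte 1088, which is not a multiple of 256, so a bind group using that offset would fail validation.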
The debugging involved systematic golden-value comparison against PyTorch, checking each intermediate buffer across all 8192 elements. Every single shader operation matched to 6 decimal places. The divergence was in the model architecture, not the compute.
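A per-buffer check like this can be sketched as a small comparison helper that reports the largest absolute difference and the first element outside tolerance. It assumes both buffers are already available as flat lists of floats (for example, a staging-buffer readback and values from `golden_dump.py`, whose output format is not described here); `atol=1e-6` approximates "matched to 6 decimal places". The function name and signature are hypothetical.

```python
import math

def compare_buffers(name, gpu, golden, atol=1e-6, rtol=0.0):
    """Compare a GPU readback with golden values; returns (ok, report)."""
    if len(gpu) != len(golden):
        return False, f"{name}: length {len(gpu)} != {len(golden)}"
    worst_i, worst, first_bad = 0, 0.0, None
    for i, (a, b) in enumerate(zip(gpu, golden)):
        if math.isfinite(a) and math.isfinite(b):
            diff = abs(a - b)
        else:
            diff = 0.0 if a == b else math.inf   # NaN never matches; equal infinities do
        if diff > worst:
            worst_i, worst = i, diff
        if first_bad is None and diff > atol + rtol * abs(b):
            first_bad = i
    ok = first_bad is None
    report = f"{name}: max |diff| {worst:.2e} at [{worst_i}]"
    if not ok:
        report += f", first mismatch at [{first_bad}]: {gpu[first_bad]!r} vs {golden[first_bad]!r}"
    return ok, report

ok, report = compare_buffers("ssu_out", [0.0, 1.0, 2.0], [0.0, 1.0, 2.01])
print(report)
```

Running a helper like this after every op of a layer localizes the first diverging buffer; when every op passes, as happened here, the remaining suspect is the model definition itself.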
Files
- `mamba_runtime.js` -- WebGPU device init, shader compilation, safetensors weight loading (byte-range fetch, BF16->F32 CPU conversion), forward pass orchestration, generation loop
- `serve_mamba.js` -- Node.js dev server with Range request support for weights, tokenize/detokenize endpoints
- `index.html` -- Test page with Initialize/Load/Generate buttons
- `shaders/*.wgsl` -- 12 WGSL compute shaders
- `golden_dump.py` -- PyTorch golden value dumper for debugging
- `REPORT.md` -- Detailed build report
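Byte-range loading works because of the safetensors layout: an 8-byte little-endian header length, then a JSON header giving each tensor's `dtype`, `shape` and `data_offsets` (relative to the end of the header), then the raw data. The sketch below parses that header from a prefix of the file and turns one tensor entry into an HTTP `Range` value. It does no networking, the tensor name is hypothetical, and how `mamba_runtime.js` actually issues its requests is not shown in this README.

```python
import json
import struct

def parse_safetensors_header(prefix: bytes):
    """Parse the header from the first bytes of a safetensors file.

    Returns (data_start, tensors): the absolute offset where tensor data
    begins, and a dict mapping tensor name -> header entry.
    """
    if len(prefix) < 8:
        raise ValueError("need at least 8 bytes")
    (header_len,) = struct.unpack("<Q", prefix[:8])
    if len(prefix) < 8 + header_len:
        raise ValueError(f"need {8 + header_len} bytes to read the full header")
    header = json.loads(prefix[8:8 + header_len])
    header.pop("__metadata__", None)   # optional free-form metadata, not a tensor
    return 8 + header_len, header

def range_header(data_start: int, entry: dict) -> str:
    """HTTP Range value for one tensor (Range end offsets are inclusive)."""
    begin, end = entry["data_offsets"]
    return f"bytes={data_start + begin}-{data_start + end - 1}"

# Tiny synthetic file: one hypothetical 2x2 BF16 tensor (8 bytes of data).
hdr = json.dumps({"__metadata__": {"format": "pt"},
                  "demo.weight": {"dtype": "BF16", "shape": [2, 2], "data_offsets": [0, 8]}}).encode()
blob = struct.pack("<Q", len(hdr)) + hdr + bytes(8)
start, tensors = parse_safetensors_header(blob)
print(range_header(start, tensors["demo.weight"]))
```

A loader along these lines fetches the first 8 bytes, then the header, then one range per tensor. Sharded Hugging Face checkpoints add a `model.safetensors.index.json` that maps each tensor name to its shard file.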
Limitations
- Single-token decode only (no batch/prefill optimization)
- F32 weights (no quantization yet; the full ~14GB BF16 checkpoint is loaded and widened to F32)
- Tokenization requires Python server-side (no in-browser tokenizer)
- Tested only on an AMD RDNA 3.5 iGPU; other GPUs may need limit adjustments
License
Apache 2.0
Credits
Built by Joshua (@LJTSG) and Claude (Anthropic Opus 4.6). Model: tiiuae/falcon-mamba-7b-instruct. Shaders ported from gfx1151_runtime (Vulkan compute SSM runtime for AMD iGPU).