Mamba WebGPU -- First Browser-Native SSM Inference Engine

Falcon-Mamba 7B running in a browser tab. Pure WebGPU compute shaders. No MLC, no TVM, no WASM, no compilation step. 12 hand-written WGSL shaders. First ever browser-native Mamba/SSM inference.

What This Is

A complete inference runtime for Falcon-Mamba-7B-Instruct that runs entirely in the browser using WebGPU compute shaders. No server-side inference -- the model loads into GPU memory via the browser's WebGPU API and generates text using hand-written WGSL compute shaders.

This is NOT a transformer runtime. This is an SSM (State Space Model) runtime for the Mamba architecture, which uses persistent recurrent state instead of a KV cache. The state is fixed-size (38MB) regardless of context length.

Why This Matters

  • WebLLM ships transformer models to the browser. This ships SSM models.
  • MLC/TVM don't support Mamba architecture (confirmed)
  • The SSM state IS persistent memory -- save it, restore it, the entity remembers
  • Fixed 38MB state vs unbounded KV cache growth
  • No server needed for inference

Quick Start

# Clone this repo
git clone https://huggingface.co/LJTSG/mamba-webgpu

# Start the dev server (serves weights from HF cache via byte-range requests)
node serve_mamba.js

# Open http://localhost:8140
# Click: Initialize -> Load Weights -> Generate

Requirements:

  • Falcon-Mamba-7B-Instruct weights in your HuggingFace cache (~/.cache/huggingface/hub/models--tiiuae--falcon-mamba-7b-instruct/)
  • Node.js (for the dev server)
  • Python + transformers (for tokenization)
  • Chrome/Edge with WebGPU support
  • GPU with >= 16GB accessible via WebGPU (tested on AMD Strix Halo iGPU with 64GB unified memory)

Architecture

Token -> Embedding lookup
  -> 64x Mamba Layer:
      RMSNorm -> in_proj GEMV -> split(x, gate)
      -> conv1d_step (persistent state)
      -> SiLU
      -> x_proj GEMV -> RMSNorm(dt_pre, B, C)  [Falcon-Mamba specific]
      -> dt_proj GEMV -> softplus
      -> SSU (selective state update, persistent state)
      -> SiLU(gate) * hidden_y
      -> out_proj GEMV -> residual add
  -> Final RMSNorm -> lm_head GEMV -> Temperature sampling
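
To make the middle of the layer concrete, here is a minimal CPU sketch of the weightless RMSNorm and one decode step of the selective state update. It is not the code in ssu.wgsl. The function names, the row-major state layout, the eps default, and the optional D skip term (standard in Mamba but not shown in the diagram above) are all assumptions for illustration. dt is expected to be the post-softplus value.

```javascript
// Weightless RMSNorm: x / sqrt(mean(x^2) + eps). eps default is an assumption.
function rmsNormNoWeight(v, eps = 1e-6) {
  let sumSq = 0;
  for (let i = 0; i < v.length; i++) sumSq += v[i] * v[i];
  const scale = 1 / Math.sqrt(sumSq / v.length + eps);
  return Float32Array.from(v, (e) => e * scale);
}

// One autoregressive step of the selective state update.
// Assumed shapes: x, dt, D have length dInner; B, C have length dState;
// aLog and h hold dInner * dState values in row-major order.
// h is updated in place (this is the persistent state).
function ssuStep({ h, x, dt, aLog, B, C, D, dState }) {
  const dInner = x.length;
  const y = new Float32Array(dInner);
  for (let i = 0; i < dInner; i++) {
    let acc = 0;
    for (let n = 0; n < dState; n++) {
      const idx = i * dState + n;
      const A = -Math.exp(aLog[idx]);      // A = -exp(A_log), always negative
      const decay = Math.exp(dt[i] * A);   // discretized state decay in (0, 1]
      h[idx] = h[idx] * decay + dt[i] * B[n] * x[i];
      acc += h[idx] * C[n];
    }
    // D skip connection is an assumption; pass null to leave it out.
    y[i] = acc + (D ? D[i] * x[i] : 0);
  }
  return y;
}
```

Because A is negative, the decay factor stays at or below 1, so old state fades instead of growing. Calling ssuStep repeatedly with the same h is what carries memory from token to token.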

12 WGSL Compute Shaders:

| Shader | Purpose | Workgroup |
|---|---|---|
| rmsnorm.wgsl | Root mean square normalization (with weights) | 64 threads |
| rmsnorm_noweight.wgsl | RMSNorm without learned weights (B/C/dt normalization) | 64 threads |
| matmul_gemv.wgsl | Matrix-vector product (M=1 specialized) | 64 threads/row |
| conv1d_step.wgsl | Autoregressive depthwise conv1d with state cache | 64 threads |
| ssu.wgsl | Selective state update (SSM scan, the core Mamba op) | 16 threads |
| silu.wgsl | SiLU/Swish activation, in-place | 64 threads |
| softplus.wgsl | Softplus activation | 64 threads |
| embedding.wgsl | Embedding table lookup | - |
| elementwise_mul.wgsl | Element-wise multiply, in-place | 64 threads |
| add_residual.wgsl | Residual connection add, in-place | 64 threads |
| sample.wgsl | Temperature-based multinomial sampling | 256 threads |
| bf16_to_f32.wgsl | BFloat16 to Float32 conversion | 64 threads |
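
For readers unfamiliar with the last step, temperature-based multinomial sampling can be sketched on the CPU as below. This is an illustration of the general technique, not a port of sample.wgsl. The function name and the injectable rand parameter are hypothetical.

```javascript
// Sample a token index from logits using softmax with temperature.
// rand is injectable so the function can be tested deterministically.
function sampleWithTemperature(logits, temperature, rand = Math.random) {
  if (!(temperature > 0)) throw new RangeError("temperature must be > 0");
  // Subtract the max logit before exponentiating for numerical stability.
  let maxLogit = -Infinity;
  for (const l of logits) if (l > maxLogit) maxLogit = l;
  const weights = new Float64Array(logits.length);
  let total = 0;
  for (let i = 0; i < logits.length; i++) {
    weights[i] = Math.exp((logits[i] - maxLogit) / temperature);
    total += weights[i];
  }
  // Walk the unnormalized cumulative distribution until r is used up.
  let r = rand() * total;
  for (let i = 0; i < weights.length; i++) {
    r -= weights[i];
    if (r < 0) return i;
  }
  return weights.length - 1; // guard against floating-point rounding
}
```

Lower temperatures sharpen the distribution toward the highest logit; higher temperatures flatten it.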

Performance

Tested on AMD Strix Halo (Radeon 8060S iGPU, RDNA 3.5, 64GB unified memory):

  • ~3 tok/s (~180ms per token)
  • ~60s weight loading (14GB BF16 checkpoint via byte-range fetch, converted to F32 on the CPU)
  • 38MB persistent SSM state (64 layers x 608KB)
  • ~960 shader dispatches per token (15 ops x 64 layers)
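
The state and dispatch figures above can be reproduced with a little arithmetic. The layer count and the 15 ops per layer come from this README. The other dimensions (dInner = 8192, dState = 16, dConv = 4) are assumed from the Falcon-Mamba-7B configuration, and the sketch further assumes the conv1d cache keeps dConv - 1 past inputs per channel in F32.

```javascript
// Assumed model dimensions; only `layers` is stated in this README.
const cfg = { layers: 64, dInner: 8192, dState: 16, dConv: 4, bytesPerFloat: 4 };

function stateBytesPerLayer({ dInner, dState, dConv, bytesPerFloat }) {
  const ssm = dInner * dState * bytesPerFloat;        // recurrent SSM state
  const conv = dInner * (dConv - 1) * bytesPerFloat;  // conv1d history (assumed dConv - 1 taps)
  return ssm + conv;
}

const perLayerKiB = stateBytesPerLayer(cfg) / 1024;   // 512 KiB + 96 KiB
const totalMiB = (perLayerKiB * cfg.layers) / 1024;
const dispatchesPerToken = 15 * cfg.layers;           // 15 ops per layer, from the list above

console.log({ perLayerKiB, totalMiB, dispatchesPerToken });
```

None of these quantities depends on how many tokens have been processed, which is the point of the fixed-size state.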

The Build Story

Built over 36 hours across two sessions. Six bugs stood between "all zeros" and coherent output:

  1. Buffer alignment -- WebGPU requires storage buffer binding offsets to be a multiple of minStorageBufferOffsetAlignment (256 bytes by default). An unaligned offset silently invalidated entire command encoders.
  2. A_log transform -- Falcon-Mamba stores A_log; the SSU needs A = -exp(A_log) for proper state decay.
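  3. Storage buffer limit -- The SSU shader uses 9 storage buffers; the default maxStorageBuffersPerShaderStage limit is 8, so a higher limit has to be requested when the device is created.
  4. Illegal buffer flags -- MAP_READ can only be combined with COPY_DST, so a buffer cannot have both MAP_READ and STORAGE usage.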
  5. Diagnostic overhead -- Per-token GPU readbacks for debugging were causing device timeouts.
  6. Missing RMSNorm on B, C, dt_pre -- Falcon-Mamba applies weightless RMSNorm to B, C, and dt_pre before the SSU. Standard Mamba does not. This was the final bug -- every shader was correct, but we were implementing the wrong model.
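
Bugs 1, 3, and 4 are generic WebGPU pitfalls, so here is a minimal sketch of guards against them. The helper names are hypothetical and this is not code from mamba_runtime.js. The flag values mirror the GPUBufferUsage constants so the helpers also run outside a browser.

```javascript
// Subset of GPUBufferUsage flag values from the WebGPU specification.
const USAGE = {
  MAP_READ: 0x0001, MAP_WRITE: 0x0002,
  COPY_SRC: 0x0004, COPY_DST: 0x0008, STORAGE: 0x0080,
};

// Bug 1: round a byte offset up to the storage binding alignment.
// 256 is the default minStorageBufferOffsetAlignment; pass the device's value if it differs.
function alignUp(offset, alignment = 256) {
  return Math.ceil(offset / alignment) * alignment;
}

// Bug 3: build the requiredLimits object for adapter.requestDevice(),
// failing early if the adapter cannot offer enough storage buffers.
function requiredLimitsFor(adapterLimits, storageBuffersNeeded) {
  const available = adapterLimits.maxStorageBuffersPerShaderStage;
  if (available < storageBuffersNeeded) {
    throw new Error(`adapter offers ${available} storage buffers, need ${storageBuffersNeeded}`);
  }
  return { maxStorageBuffersPerShaderStage: storageBuffersNeeded };
}

// Bug 4: check only the mapping rules. MAP_READ may be paired with COPY_DST
// and nothing else; MAP_WRITE may be paired with COPY_SRC and nothing else.
function mappableUsageIsLegal(usage) {
  if (usage & USAGE.MAP_READ) return (usage & ~(USAGE.MAP_READ | USAGE.COPY_DST)) === 0;
  if (usage & USAGE.MAP_WRITE) return (usage & ~(USAGE.MAP_WRITE | USAGE.COPY_SRC)) === 0;
  return true;
}
```

In a browser, the limits helper would be used as adapter.requestDevice({ requiredLimits: requiredLimitsFor(adapter.limits, 9) }). Reading results back then takes two buffers: a STORAGE | COPY_SRC buffer the shader writes to, and a MAP_READ | COPY_DST staging buffer it is copied into.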

The debugging involved systematic golden-value comparison against PyTorch, checking each intermediate buffer across all 8192 elements. Every single shader operation matched to 6 decimal places. The divergence was in the model architecture, not the compute.
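
A golden-value check of this kind can be as small as the sketch below. It assumes the GPU buffer has already been read back into an array and that the matching PyTorch values are available as an array of the same length. The function name and the 1e-6 default tolerance (mirroring "6 decimal places") are illustrative, not taken from golden_dump.py.

```javascript
// Compare a readback buffer against reference values element by element.
// Returns the largest absolute difference and where it occurred.
function compareToGolden(actual, golden, tol = 1e-6) {
  if (actual.length !== golden.length) {
    throw new Error(`length mismatch: ${actual.length} vs ${golden.length}`);
  }
  let maxAbsDiff = 0;
  let worstIndex = -1;
  for (let i = 0; i < actual.length; i++) {
    const d = Math.abs(actual[i] - golden[i]);
    // A NaN anywhere is an immediate failure; report its position.
    if (Number.isNaN(d)) return { ok: false, maxAbsDiff: NaN, worstIndex: i };
    if (d > maxAbsDiff) { maxAbsDiff = d; worstIndex = i; }
  }
  return { ok: maxAbsDiff <= tol, maxAbsDiff, worstIndex };
}
```

Running it after each op in layer order points at the first buffer that diverges, which narrows a bad output down to one step of the pipeline.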

Files

  • mamba_runtime.js -- WebGPU device init, shader compilation, safetensors weight loading (byte-range fetch, BF16->F32 CPU conversion), forward pass orchestration, generation loop
  • serve_mamba.js -- Node.js dev server with Range request support for weights, tokenize/detokenize endpoints
  • index.html -- Test page with Initialize/Load/Generate buttons
  • shaders/*.wgsl -- 12 WGSL compute shaders
  • golden_dump.py -- PyTorch golden value dumper for debugging
  • REPORT.md -- Detailed build report
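
The BF16->F32 CPU conversion mentioned for mamba_runtime.js comes down to a 16-bit shift, because a BF16 value is the upper half of an F32 bit pattern. A minimal sketch follows; the function name is hypothetical, and it assumes the fetched bytes are viewed as a Uint16Array on a little-endian host.

```javascript
// Expand BF16 words to F32 by placing each 16-bit word in the high half
// of a 32-bit float; the low 16 mantissa bits become zero.
function bf16ToF32(bf16Words) {
  const out = new Float32Array(bf16Words.length);
  const bits = new Uint32Array(out.buffer); // same memory, viewed as raw bits
  for (let i = 0; i < bf16Words.length; i++) {
    bits[i] = bf16Words[i] << 16; // wraps to unsigned when stored in the Uint32Array
  }
  return out;
}
```

The conversion is exact, but the output takes twice as many bytes as the input.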

Limitations

  • Single-token decode only (no batch/prefill optimization)
  • F32 weights (no quantization yet -- the full 14GB BF16 checkpoint is loaded and expanded to F32)
  • Tokenization requires Python server-side (no in-browser tokenizer)
  • Tested only on an AMD RDNA 3.5 iGPU; other GPUs may need limit adjustments

License

Apache 2.0

Credits

Built by Joshua (@LJTSG) and Claude (Anthropic Opus 4.6). Model: tiiuae/falcon-mamba-7b-instruct. Shaders ported from gfx1151_runtime (Vulkan compute SSM runtime for AMD iGPU).
