Mamba-2 SSD NKI Kernel for AWS Neuron
NeuronCore-optimized NKI kernel implementing the Mamba-2 Structured State-space Duality (SSD) algorithm. Replaces the Mamba2Mixer module in transformers with a hardware-accelerated version that performs chunk-wise parallel scan on NeuronCore Tensor and Vector engines.
Why This Kernel Exists
On Neuron hardware, the standard Mamba2Mixer falls back to a naive O(L^2) PyTorch implementation (torch_forward()) since the Triton-based mamba_chunk_scan_combined kernel is CUDA-only. This NKI kernel provides the efficient O(L) chunk-parallel algorithm directly on NeuronCore.
Performance (Granite 4.0 Dimensions)
Profiled on trn2.3xlarge (SDK 2.30, LNC=2):
| Metric | Baseline (nkilib) | Optimized (this kernel) | Improvement |
|---|---|---|---|
| Total time | 452 us | 415 us | 8.2% faster |
| VectorE active | 352 us | 316 us | 10.2% reduction |
| Transpose FLOPs | 470 MFLOP | 336 MFLOP | 28.6% reduction |
| VectorE instructions | 2377 | 2064 | 13.2% fewer |
Dimensions: batch=1, nheads=32, seqlen=256, headdim=64, dstate=128, chunk_size=128
Profile-Guided Optimizations
The baseline nkilib SSD kernel was VectorE-bottlenecked (61% of execution time). Optimizations:
- C pre-transposed: Accept C as (batch, dstate, seqlen) instead of (batch, seqlen, dstate), eliminating one nc_transpose per chunk
- Shared broadcast: Single exp_cs_last broadcast buffer reused for both Q and dstate dimensions (saves 4 stream_shuffle instructions per chunk when dstate == chunk_size)
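As an illustration of the pre-transposed C layout, the host-side transpose can be sketched in NumPy. The helper name pretranspose_c is hypothetical, and whether the caller or the kernel wrapper is expected to perform this step is an assumption here; the sketch only shows the layout change described above.

```python
import numpy as np

def pretranspose_c(c: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (batch, seqlen, dstate) -> (batch, dstate, seqlen).

    Doing this once on the host removes the need for a per-chunk
    transpose inside the kernel.
    """
    if c.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {c.shape}")
    # transpose() returns a view; ascontiguousarray materializes the new layout.
    return np.ascontiguousarray(c.transpose(0, 2, 1))

# Dimensions taken from the profiling run above.
batch, seqlen, dstate = 1, 256, 128
c = np.random.default_rng(0).standard_normal((batch, seqlen, dstate)).astype(np.float32)
c_t = pretranspose_c(c)
```

After the call, c_t[b, :, t] holds the same values as c[b, t, :], so each chunk of C can be sliced along the last axis without further rearrangement.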
Supported Models
Any Mamba-2 model using Mamba2Mixer in transformers:
| Model | Tested | Notes |
|---|---|---|
| IBM Granite 4.0 | Primary target | nheads=128 (TP to 32), headdim=64, dstate=128 |
| state-spaces/mamba2-* | Supported | Standard Mamba-2 configs |
Model configuration requirements:
- chunk_size == 128 (hardcoded for NeuronCore tile alignment)
- dstate <= 128 (fits in single tile)
- headdim <= 128 (fits in single tile)
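These constraints can be checked before loading the kernel. The following is a minimal sketch; the function name check_ssd_kernel_shapes and the TILE constant are hypothetical and not part of this repository, and only the three limits listed above are encoded.

```python
TILE = 128  # tile size implied by the constraints above

def check_ssd_kernel_shapes(chunk_size: int, dstate: int, headdim: int) -> list[str]:
    """Return a list of human-readable problems; an empty list means the config fits."""
    problems = []
    if chunk_size != TILE:
        problems.append(f"chunk_size must be exactly {TILE}, got {chunk_size}")
    if dstate > TILE:
        problems.append(f"dstate must be <= {TILE}, got {dstate}")
    if headdim > TILE:
        problems.append(f"headdim must be <= {TILE}, got {headdim}")
    return problems

# Granite 4.0 dimensions from the table above.
granite_problems = check_ssd_kernel_shapes(chunk_size=128, dstate=128, headdim=64)
```

Returning a list rather than raising on the first failure lets a caller report every mismatch at once.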
Requirements
- AWS Neuron SDK 2.30+ (NKI 0.4.0+)
- transformers >= 4.45 with Mamba2Mixer support
- kernels library (Hugging Face kernel loading)
Usage
from transformers import AutoModel, KernelConfig

model_id = "ibm-granite/granite-4.0-tiny"  # Any Mamba-2 model

kernel_config = KernelConfig({
    "Mamba2Mixer": "jburtoft/mamba2-ssd-neuron-kernels:NeuronMamba2MixerForward",
})

model = AutoModel.from_pretrained(
    model_id,
    kernel_config=kernel_config,
    device_map="neuron",
)
Architecture
NeuronMamba2MixerLayout - Weight structure (matches Mamba2Mixer)
NeuronMamba2MixerForward - Forward pass:
1. in_proj - Input projections (PyTorch)
2. conv1d - Causal convolution (PyTorch)
3. ssd_kernel - @nki.jit SSD scan (NKI on NeuronCore)
4. norm + gate + out_proj - Output processing (PyTorch)
ssd_kernel - Optimized NKI kernel:
- Chunk-wise parallel scan (O(L) complexity)
- Intra-chunk: TensorE nc_matmul for C*B^T attention matrix
- Inter-chunk: VectorE sequential state propagation
- Pre-transposed C eliminates runtime nc_transpose
Algorithm
The SSD (Structured State-space Duality) algorithm decomposes the Mamba-2 recurrence into three steps:
- Intra-chunk parallel: Within each 128-token chunk, compute Y = (C*B^T * mask) @ (dt*x) using Tensor Engine matmuls
- Inter-chunk sequential: Propagate SSM state between chunks with exponential decay
- State-to-output: Apply accumulated state contribution via a C^T @ state matmul
This achieves O(L) complexity (linear in sequence length) while maximizing Tensor Engine utilization for the matrix multiplications within each chunk.
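The decomposition can be sketched in NumPy for a single head and checked against the token-by-token recurrence. This is a reference sketch for the math only, not the NKI kernel. The function names, the (seqlen, headdim) and (seqlen, dstate) layouts, a scalar per-head A, and the omission of the D skip connection are all assumptions made for illustration.

```python
import numpy as np

def ssd_sequential(x, dt, A, B, C):
    """Reference recurrence: h_t = exp(dt_t*A) * h_{t-1} + dt_t * B_t x_t^T, y_t = C_t h_t."""
    seqlen, headdim = x.shape
    h = np.zeros((B.shape[1], headdim))  # state: (dstate, headdim)
    y = np.zeros_like(x)
    for t in range(seqlen):
        h = np.exp(dt[t] * A) * h + dt[t] * np.outer(B[t], x[t])
        y[t] = C[t] @ h
    return y

def ssd_chunked(x, dt, A, B, C, chunk_size=128):
    """Chunk-wise evaluation of the same recurrence (single head)."""
    seqlen, headdim = x.shape
    assert seqlen % chunk_size == 0, "sketch assumes seqlen is a multiple of chunk_size"
    state = np.zeros((B.shape[1], headdim))
    y = np.empty_like(x)
    causal = np.tril(np.ones((chunk_size, chunk_size), dtype=bool))
    for start in range(0, seqlen, chunk_size):
        sl = slice(start, start + chunk_size)
        xc, dtc, Bc, Cc = x[sl], dt[sl], B[sl], C[sl]
        cs = np.cumsum(dtc * A)                # inclusive cumulative log-decay
        diff = cs[:, None] - cs[None, :]       # log-decay from token j to token i
        # Causal decay mask; masked entries are zeroed before exp to avoid overflow.
        mask = np.where(causal, np.exp(np.where(causal, diff, 0.0)), 0.0)
        dtx = dtc[:, None] * xc
        # Intra-chunk: Y = (C B^T * mask) @ (dt * x)
        y_intra = ((Cc @ Bc.T) * mask) @ dtx
        # State-to-output: contribution of the state carried in from earlier chunks
        y_state = np.exp(cs)[:, None] * (Cc @ state)
        y[sl] = y_intra + y_state
        # Inter-chunk: decay the old state and add this chunk's contribution
        decay_to_end = np.exp(cs[-1] - cs)
        state = np.exp(cs[-1]) * state + (Bc * decay_to_end[:, None]).T @ dtx
    return y

# Small shapes for a quick check; chunk_size matches the kernel's 128.
rng = np.random.default_rng(0)
seqlen, headdim, dstate = 256, 4, 8
x = rng.standard_normal((seqlen, headdim))
dt = rng.uniform(0.01, 0.1, size=seqlen)
B = rng.standard_normal((seqlen, dstate))
C = rng.standard_normal((seqlen, dstate))
A = -1.0  # negative, so exp(dt * A) is a decay factor in (0, 1)
max_err = np.max(np.abs(ssd_chunked(x, dt, A, B, C) - ssd_sequential(x, dt, A, B, C)))
```

Only the state update runs sequentially across chunks; the work inside each chunk is expressed as matrix multiplications, which is the part the kernel maps to the Tensor Engine.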
License
Apache 2.0