Mamba-2 SSD NKI Kernel for AWS Neuron

NeuronCore-optimized NKI kernel implementing the Mamba-2 Structured State-space Duality (SSD) algorithm. Replaces the Mamba2Mixer module in transformers with a hardware-accelerated version that performs chunk-wise parallel scan on NeuronCore Tensor and Vector engines.

Why This Kernel Exists

On Neuron hardware, the standard Mamba2Mixer falls back to a naive O(L^2) PyTorch implementation (torch_forward()) since the Triton-based mamba_chunk_scan_combined kernel is CUDA-only. This NKI kernel provides the efficient O(L) chunk-parallel algorithm directly on NeuronCore.
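As a rough illustration of the scaling difference, the sketch below counts multiply-accumulates (MACs) for a fully materialized L x L form versus a chunked form. The cost model and the helper names `naive_macs` and `chunked_macs` are assumptions made for illustration; they are not measured from torch_forward() or from this kernel.

```python
def naive_macs(seqlen, headdim, dstate):
    # Assumed cost model: build an L x L score matrix (L*L*dstate MACs)
    # and apply it to the inputs (L*L*headdim MACs).
    return seqlen * seqlen * (dstate + headdim)


def chunked_macs(seqlen, headdim, dstate, chunk_size=128):
    # Assumed cost model per chunk: Q x Q scores and Q x Q apply, plus a
    # state read-out and a state update (Q*dstate*headdim MACs each).
    n_chunks = seqlen // chunk_size
    intra = chunk_size * chunk_size * (dstate + headdim)
    state = 2 * chunk_size * dstate * headdim
    return n_chunks * (intra + state)


for seqlen in (256, 2048, 16384):
    print(seqlen, naive_macs(seqlen, 64, 128), chunked_macs(seqlen, 64, 128))
```

Under this model, doubling the sequence length quadruples the naive count but only doubles the chunked count, which is the practical meaning of O(L^2) versus O(L) here.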

Performance (Granite 4.0 Dimensions)

Profiled on trn2.3xlarge (SDK 2.30, LNC=2):

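| Metric               | Baseline (nkilib) | Optimized (this kernel) | Improvement     |
|----------------------|-------------------|-------------------------|-----------------|
| Total time           | 452 us            | 415 us                  | 8.2% faster     |
| VectorE active       | 352 us            | 316 us                  | 10.2% reduction |
| Transpose FLOPs      | 470 MFLOP         | 336 MFLOP               | 28.6% reduction |
| VectorE instructions | 2377              | 2064                    | 13.2% fewer     |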

Dimensions: batch=1, nheads=32, seqlen=256, headdim=64, dstate=128, chunk_size=128
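The Improvement column can be re-derived from the baseline and optimized figures as a relative reduction. The snippet below is only a convenience check on the numbers quoted above; the helper name `reduction_pct` is hypothetical.

```python
# (baseline, optimized) pairs copied from the profile results above.
rows = {
    "Total time (us)": (452, 415),
    "VectorE active (us)": (352, 316),
    "Transpose FLOPs (MFLOP)": (470, 336),
    "VectorE instructions": (2377, 2064),
}


def reduction_pct(baseline, optimized):
    # Relative reduction with respect to the baseline, in percent.
    return 100.0 * (baseline - optimized) / baseline


for name, (base, opt) in rows.items():
    print(f"{name}: {reduction_pct(base, opt):.1f}%")
```

Because the quoted inputs are themselves rounded, a recomputed percentage can differ from the reported one in the last digit (the Transpose FLOPs row is the one where this is visible).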

Profile-Guided Optimizations

The baseline nkilib SSD kernel was VectorE-bottlenecked (61% of execution time). Optimizations:

  1. C pre-transposed: Accept C as (batch, dstate, seqlen) instead of (batch, seqlen, dstate), eliminating 1 nc_transpose per chunk
  2. Shared broadcast: Single exp_cs_last broadcast buffer reused for both Q and dstate dimensions (saves 4 stream_shuffle instructions per chunk when dstate == chunk_size)
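The layout change in item 1 can be sketched with NumPy. This is a minimal illustration, assuming the transpose is done once outside the kernel; where the real wrapper performs it is not shown here, and the variable names are hypothetical.

```python
import numpy as np

batch, seqlen, dstate, chunk_size = 1, 256, 128, 128
rng = np.random.default_rng(0)

# Original layout: (batch, seqlen, dstate).
C = rng.standard_normal((batch, seqlen, dstate)).astype(np.float32)

# One-time transpose to (batch, dstate, seqlen), made contiguous in memory.
C_t = np.ascontiguousarray(C.transpose(0, 2, 1))

# Each chunk is now a plain slice along the last axis. It already has the
# (dstate, chunk_size) shape that would otherwise need a per-chunk transpose.
for start in range(0, seqlen, chunk_size):
    chunk_t = C_t[0, :, start:start + chunk_size]
    assert chunk_t.shape == (dstate, chunk_size)
    assert np.array_equal(chunk_t, C[0, start:start + chunk_size].T)
```

The trade-off is one transpose over the whole tensor up front instead of one per chunk inside the kernel loop.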

Supported Models

Any Mamba-2 model using Mamba2Mixer in transformers:

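| Model                 | Tested         | Notes                                        |
|-----------------------|----------------|----------------------------------------------|
| IBM Granite 4.0       | Primary target | nheads=128 (TP to 32), headdim=64, dstate=128 |
| state-spaces/mamba2-* | Supported      | Standard Mamba-2 configs                     |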

Kernel shape requirements:

  • chunk_size == 128 (hardcoded for NeuronCore tile alignment)
  • dstate <= 128 (fits in single tile)
  • headdim <= 128 (fits in single tile)
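A configuration can be checked against these limits before loading the kernel. The helper below is a hypothetical sketch, not part of this repository; it encodes only the three constraints listed above.

```python
CHUNK_SIZE = 128  # fixed by the kernel, per the list above
MAX_TILE = 128    # upper bound for dstate and headdim, per the list above


def check_ssd_shapes(chunk_size, dstate, headdim):
    """Return a list of violated constraints (empty if the config fits)."""
    problems = []
    if chunk_size != CHUNK_SIZE:
        problems.append(f"chunk_size must be {CHUNK_SIZE}, got {chunk_size}")
    if dstate > MAX_TILE:
        problems.append(f"dstate must be <= {MAX_TILE}, got {dstate}")
    if headdim > MAX_TILE:
        problems.append(f"headdim must be <= {MAX_TILE}, got {headdim}")
    return problems


# Granite 4.0 dimensions from the performance section.
print(check_ssd_shapes(chunk_size=128, dstate=128, headdim=64))
# A config that breaks two of the limits.
print(check_ssd_shapes(chunk_size=256, dstate=256, headdim=64))
```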

Requirements

  • AWS Neuron SDK 2.30+ (NKI 0.4.0+)
  • transformers >= 4.45 with Mamba2Mixer support
  • kernels library (HuggingFace kernel loading)

Usage

from transformers import AutoModel, KernelConfig

model_id = "ibm-granite/granite-4.0-tiny"  # Any Mamba-2 model

kernel_config = KernelConfig({
    "Mamba2Mixer": "jburtoft/mamba2-ssd-neuron-kernels:NeuronMamba2MixerForward",
})

model = AutoModel.from_pretrained(
    model_id,
    kernel_config=kernel_config,
    device_map="neuron",
)

Architecture

NeuronMamba2MixerLayout     - Weight structure (matches Mamba2Mixer)
NeuronMamba2MixerForward    - Forward pass:
  1. in_proj                 - Input projections (PyTorch)
  2. conv1d                  - Causal convolution (PyTorch)
  3. ssd_kernel              - @nki.jit SSD scan (NKI on NeuronCore)
  4. norm + gate + out_proj  - Output processing (PyTorch)
ssd_kernel                   - Optimized NKI kernel:
  - Chunk-wise parallel scan (O(L) complexity)
  - Intra-chunk: TensorE nc_matmul for C*B^T attention matrix
  - Inter-chunk: VectorE sequential state propagation
  - Pre-transposed C eliminates runtime nc_transpose

Algorithm

The SSD (Structured State-space Duality) algorithm decomposes the Mamba-2 recurrence into three steps:

  1. Intra-chunk parallel: Within each 128-token chunk, compute Y = (C*B^T * mask) @ (dt*x) using Tensor Engine matmuls
  2. Inter-chunk sequential: Propagate SSM state between chunks with exponential decay
  3. State-to-output: Apply accumulated state contribution via C^T @ state matmul

This achieves O(L) complexity (linear in sequence length) while maximizing Tensor Engine utilization for the matrix multiplications within each chunk.
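The three steps can be checked against the token-by-token recurrence with a small NumPy reference. This is a sketch for a single head under stated assumptions: a scalar negative A, per-token dt > 0, no D skip connection or gating, and a sequence length divisible by the chunk size. The function names are hypothetical, and this is not the NKI kernel itself.

```python
import numpy as np


def ssd_sequential(x, dt, A, B, C):
    """Recurrence h_t = exp(dt_t*A)*h_{t-1} + dt_t*outer(B_t, x_t); y_t = C_t @ h_t."""
    seqlen, headdim = x.shape
    h = np.zeros((B.shape[1], headdim))
    y = np.zeros((seqlen, headdim))
    for t in range(seqlen):
        h = np.exp(dt[t] * A) * h + dt[t] * np.outer(B[t], x[t])
        y[t] = C[t] @ h
    return y


def ssd_chunked(x, dt, A, B, C, chunk_size=128):
    seqlen, headdim = x.shape
    assert seqlen % chunk_size == 0, "sketch assumes full chunks only"
    h = np.zeros((B.shape[1], headdim))  # state carried between chunks
    y = np.zeros((seqlen, headdim))
    mask = np.tril(np.ones((chunk_size, chunk_size), dtype=bool))
    for s in range(0, seqlen, chunk_size):
        e = s + chunk_size
        xc, dtc, Bc, Cc = x[s:e], dt[s:e], B[s:e], C[s:e]
        cs = np.cumsum(dtc * A)  # inclusive cumulative log-decay
        # Step 1: causal decay mask exp(cs[i] - cs[j]) for i >= j, else 0.
        decay = np.exp(np.where(mask, cs[:, None] - cs[None, :], -np.inf))
        y_intra = ((Cc @ Bc.T) * decay) @ (dtc[:, None] * xc)
        # Step 3: contribution of the state entering this chunk.
        y_state = np.exp(cs)[:, None] * (Cc @ h)
        y[s:e] = y_intra + y_state
        # Step 2: decay the old state to the chunk end and add this chunk.
        w = np.exp(cs[-1] - cs) * dtc
        h = np.exp(cs[-1]) * h + (Bc * w[:, None]).T @ xc
    return y


rng = np.random.default_rng(0)
seqlen, headdim, dstate = 256, 64, 128  # profiled dimensions, one head
x = rng.standard_normal((seqlen, headdim))
dt = rng.uniform(0.01, 0.1, size=seqlen)
B = rng.standard_normal((seqlen, dstate))
C = rng.standard_normal((seqlen, dstate))
y_ref = ssd_sequential(x, dt, -1.0, B, C)
y_chk = ssd_chunked(x, dt, -1.0, B, C)
print("max abs difference:", np.abs(y_ref - y_chk).max())
```

The per-chunk matmuls in steps 1 and 3 are the part that maps naturally onto a matrix engine, while the state update in step 2 is the only dependency that runs in sequence across chunks.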

License

Apache 2.0
