Add flash_attn_with_kvcache paged-decode kernel (port from huggingface/transformers#45977)
Add flash_attn_with_kvcache (paged decode with block tables)
Why
transformers looks up flash_attn_with_kvcache on the loaded kernel
module to pick the paged decode path with block tables (see
modeling_flash_attention_utils._lazy_imports, lines 183-202). Today this
package exposes only flash_attn_varlen_func, so when a user sets
attn_implementation="kernels-community/metal-flash-sdpa":
- single-stream model.generate() crashes with
  TypeError: 'NoneType' object is not callable, because the standard
  generate path needs flash_attn_func or flash_attn_with_kvcache to
  handle the (B, 1) decode shape and neither exists; and
- continuous batching falls back to the varlen path even at decode,
  losing the block-table fast path (the warning at
  modeling_flash_attention_utils.py:197-202 calls this out).
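The TypeError implies the missing entry point resolves to None and is then called. The sketch below reproduces that pattern in isolation. It is hypothetical and simplified, not the transformers source: pick_decode_fn and the stand-in module are illustrative names, and only the three entry-point names come from the description above.

```python
from types import SimpleNamespace


def pick_decode_fn(kernel_module):
    """Hypothetical, simplified dispatch (NOT the transformers implementation).

    Optional entry points are looked up by name and come back as None when
    the kernel package does not export them.
    """
    fns = {
        name: getattr(kernel_module, name, None)
        for name in ("flash_attn_func", "flash_attn_varlen_func", "flash_attn_with_kvcache")
    }
    # Per the description above, the (B, 1) decode step needs one of these two.
    return fns["flash_attn_with_kvcache"] or fns["flash_attn_func"]


# Stand-in for today's package: only the varlen entry point exists.
varlen_only = SimpleNamespace(flash_attn_varlen_func=lambda *args, **kwargs: None)
decode_fn = pick_decode_fn(varlen_only)
try:
    decode_fn()
except TypeError as err:
    print(err)  # 'NoneType' object is not callable
```

Exporting flash_attn_with_kvcache from the package is what turns that None into a callable.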
What
Ports the paged_decode_attention_f32 Metal kernel proven in
huggingface/transformers#45977 (the GGUF MPS work). One simdgroup per
(request, head) pair runs a Flash-Decoding-style online softmax,
reading K/V from a paged cache via block_table[batch, t / block_size],
with no gather and no contiguous KV materialisation.
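As a reading aid, here is a minimal pure-Python reference of what one (request, head) pair computes: the online-softmax recurrence over K/V read in place through the block table. It is a sketch under stated assumptions, not a port of paged_decode_attention_f32. The nested-list cache layout [num_blocks][block_size][head][dim], equal query and KV head counts, seq_lens > 0, and the function name are all hypothetical, and the Metal kernel runs this loop cooperatively within a simdgroup rather than serially.

```python
import math


def paged_decode_attention(q, k_cache, v_cache, block_table, seq_lens, block_size):
    """Hypothetical reference for single-token decode over a paged KV cache.

    q:           [batch][head][dim]                  one query token per request
    k_cache:     [num_blocks][block_size][head][dim] (assumed layout)
    v_cache:     same layout as k_cache
    block_table: [batch][logical_block] -> physical block index
    seq_lens:    [batch] number of cached tokens per request (assumed > 0)
    """
    out = []
    for b, q_b in enumerate(q):
        out_b = []
        for h, q_bh in enumerate(q_b):
            dim = len(q_bh)
            scale = 1.0 / math.sqrt(dim)
            m = float("-inf")  # running max of the logits
            l = 0.0            # running softmax denominator
            acc = [0.0] * dim  # running unnormalised output
            for t in range(seq_lens[b]):
                # Resolve logical position t to a physical (block, slot): no gather.
                block = block_table[b][t // block_size]
                slot = t % block_size
                k = k_cache[block][slot][h]
                v = v_cache[block][slot][h]
                s = scale * sum(qi * ki for qi, ki in zip(q_bh, k))
                m_new = max(m, s)
                correction = math.exp(m - m_new)  # 0.0 on the first step
                p = math.exp(s - m_new)
                l = l * correction + p
                acc = [a * correction + p * vi for a, vi in zip(acc, v)]
                m = m_new
            out_b.append([a / l for a in acc])
        out.append(out_b)
    return out


# Usage: two identical keys get equal weight, so the output is the mean of V.
k_demo = [[[[1.0, 0.0]], [[1.0, 0.0]]]]  # 1 block, block_size 2, 1 head, dim 2
v_demo = [[[[2.0, 0.0]], [[4.0, 2.0]]]]
print(paged_decode_attention([[[0.5, 0.5]]], k_demo, v_demo, [[0]], [2], 2))
# [[[3.0, 1.0]]]
```

Because the running max m and denominator l are rescaled at every step, the loop never holds the full logit vector or a gathered K/V tensor, which is what allows reading the cache in place.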
Also bundles kv_paged_write_f32, a single-dispatch kernel that
computes the write slot from block_table + seq_lens and writes the
just-produced K/V into the paged cache. Together they cover the
decode step end-to-end.
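The slot computation that kv_paged_write_f32 performs can be sketched in pure Python. This is a hypothetical illustration: it assumes seq_lens[b] counts the tokens cached before the current step (so the new token lands at logical position seq_lens[b]), that block_table already has a block allocated for that position, and the same nested-list layout [num_blocks][block_size][head][dim]. The real kernel may use a different convention.

```python
def kv_paged_write(k_new, v_new, k_cache, v_cache, block_table, seq_lens, block_size):
    """Hypothetical write of one new K/V token per request into a paged cache.

    k_new, v_new: [batch][head][dim] for the just-produced token
    Returns the (block, slot) pair written for each request.
    """
    written = []
    for b, (k_b, v_b) in enumerate(zip(k_new, v_new)):
        pos = seq_lens[b]                          # next free logical position (assumed)
        block = block_table[b][pos // block_size]  # physical block holding pos
        slot = pos % block_size                    # offset inside that block
        for h in range(len(k_b)):
            k_cache[block][slot][h] = list(k_b[h])
            v_cache[block][slot][h] = list(v_b[h])
        written.append((block, slot))
    return written


# Usage: 3 physical blocks of 4 slots, 1 head, dim 2.
k_demo = [[[[0.0, 0.0]] for _ in range(4)] for _ in range(3)]
v_demo = [[[[0.0, 0.0]] for _ in range(4)] for _ in range(3)]
print(kv_paged_write([[[1.0, 2.0]], [[3.0, 4.0]]], [[[5.0, 6.0]], [[7.0, 8.0]]],
                     k_demo, v_demo, [[2], [0, 1]], [3, 4], 4))
# [(2, 3), (1, 0)]
```

Computing the slot on the GPU from block_table and seq_lens is what lets the write happen in a single dispatch instead of a host-side index computation per request.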
Status
This PR is a kernel-only contribution: the .metal source is
in place, but the .mm C++ binding, the torch_binding.{cpp,h} registration,
and the Python flash_attn_with_kvcache wrapper still need to be added to
match the package's conventions. I can take those on in a follow-up
once the kernel approach lands.
Validation (from #45977 on M3 Max)
- Correctness: numerically equivalent to a reference SDPA over a gathered KV
  tensor at (B=4, H_Q=16, head_dim=128, S=50), max abs diff 1.67e-6
  (fp32 rounding noise).
- Microbench: ~134 µs/call at the shape above. Replaces the
  ~16 ms/layer gather-then-SDPA path in the transformers continuous-batching
  (CB) block-table decode loop on MPS.
- End-to-end on Qwen1.5-MoE-A2.7B Q4_K_M, M3 Max, generate_batch N=8,
  MAX_NEW=128:
  - varlen (today's mflash): 121.0 tok/s aggregate
  - block-table + paged decode (this PR): 120.6 tok/s with PyTorch
    SDPA. The kernel itself is the unlock; the expectation is to match
    or exceed the varlen path once the .mm binding lands.
References: paged_decode_attention_f32 source in huggingface/transformers#45977, gguf-dequant-kernels/gguf_dequant_metal/dequantize.metal.