
Add flash_attn_with_kvcache paged-decode kernel (port from huggingface/transformers#45977)

#4
by ArthurZ HF Staff - opened

Add flash_attn_with_kvcache (paged decode with block tables)

Why

transformers looks up flash_attn_with_kvcache on the loaded kernel
module to pick the paged decode path with block tables (see
modeling_flash_attention_utils._lazy_imports, lines 183-202). Today this
package exposes only flash_attn_varlen_func, so when a user sets
attn_implementation="kernels-community/metal-flash-sdpa":

  • single-stream model.generate() crashes with TypeError: 'NoneType'
    object is not callable, because the standard generate path needs
    flash_attn_func or flash_attn_with_kvcache to handle the (B, 1)
    decode shape and neither exists; and
  • continuous batching falls back to the varlen path even at decode,
    losing the block-table fast path (the warning at
    modeling_flash_attention_utils.py:197-202 calls this out).

What

Ports the paged_decode_attention_f32 Metal kernel proven in
huggingface/transformers#45977 (the GGUF MPS work). One simdgroup per
(request, head) pair runs a Flash-Decoding-style online softmax,
reading K/V from a paged cache via block_table[batch, t / block_size]
(integer division), with no gather and no contiguous KV materialisation.
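Below is a pure-Python reference for the decode loop just described. It computes one online softmax per (request, head) pair and resolves logical positions through the block table. This is a sketch under illustrative assumptions, not the Metal kernel:

- The nested-list layouts in the docstring are assumed, not the kernel's real memory layout.
- H_Q == H_KV, so there is no grouped-query attention.
- Every seq_lens entry is at least 1.
- The function name `paged_decode_attention` is hypothetical.

```python
import math


def paged_decode_attention(q, k_cache, v_cache, block_table, seq_lens, block_size, scale):
    """Reference for one decode step (one query token per request).

    Assumed layouts (illustrative only):
      q:               [batch][heads][head_dim]
      k_cache/v_cache: [num_blocks][block_size][heads][head_dim]
      block_table:     [batch][max_blocks_per_seq] -> physical block index
      seq_lens:        [batch] number of cached tokens to attend over (>= 1)
    Assumes H_Q == H_KV (no grouped-query attention).
    """
    out = []
    for b, q_b in enumerate(q):
        out_b = []
        for h, q_bh in enumerate(q_b):
            # Running online-softmax state for this (request, head) pair,
            # the unit of work one simdgroup owns in the Metal kernel.
            m = -math.inf            # running max of scores
            l = 0.0                  # running sum of exp(score - m)
            acc = [0.0] * len(q_bh)  # running sum of exp(score - m) * v
            for t in range(seq_lens[b]):
                # Map logical position t to (physical block, offset in block).
                block = block_table[b][t // block_size]
                off = t % block_size
                k = k_cache[block][off][h]
                v = v_cache[block][off][h]
                s = scale * sum(qi * ki for qi, ki in zip(q_bh, k))
                m_new = max(m, s)
                corr = math.exp(m - m_new)  # rescale old state to the new max
                p = math.exp(s - m_new)
                l = l * corr + p
                acc = [a * corr + p * vi for a, vi in zip(acc, v)]
                m = m_new
            out_b.append([a / l for a in acc])
        out.append(out_b)
    return out
```

The running max, sum and accumulator are rescaled whenever a larger score appears. This lets the cache be read in a single pass, without storing the full score vector or gathering K/V into a contiguous buffer.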

Also bundles kv_paged_write_f32, a single-dispatch kernel that
computes the write slot from block_table + seq_lens and writes the
just-produced K/V into the paged cache. Together they cover the
decode step end-to-end.
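Below is a pure-Python sketch of how a write slot can be derived from block_table and seq_lens. The convention is an assumption: seq_lens[b] counts the tokens already cached before this step, so the new token goes to logical position seq_lens[b]. If the real kernel counts the new token in seq_lens, the slot would be seq_lens[b] - 1 instead. The helper name and the nested-list layout are hypothetical.

```python
def kv_paged_write(k_cache, v_cache, k_new, v_new, block_table, seq_lens, block_size):
    """Hypothetical sketch: write the just-produced K/V for one decode step.

    Assumes seq_lens[b] is the number of tokens already cached for request b,
    so the new token lands at logical position seq_lens[b].
    Layouts: k_new/v_new [batch][heads][head_dim];
             caches      [num_blocks][block_size][heads][head_dim].
    """
    for b in range(len(seq_lens)):
        t = seq_lens[b]
        block = block_table[b][t // block_size]  # physical block holding position t
        off = t % block_size                     # slot inside that block
        for h in range(len(k_new[b])):
            k_cache[block][off][h] = list(k_new[b][h])
            v_cache[block][off][h] = list(v_new[b][h])
```

The kernel bundled here does this in a single dispatch on the GPU. The loop above only shows the slot arithmetic.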

Status

This PR is a kernel-only contribution: the .metal source is
in place, but the .mm (Objective-C++) binding, torch_binding.{cpp,h} registration,
and Python flash_attn_with_kvcache wrapper still need to be added to
match the package's conventions. I can take those on in a follow-up
once the kernel approach lands.

Validation (from #45977 on M3 Max)

  • Correctness: numerically equivalent to a reference SDPA over a gathered KV
    tensor at (B=4, H_Q=16, head_dim=128, S=50): max abs diff 1.67e-6
    (fp32 rounding noise).
  • Microbench: ~134 µs/call at the shape above. Replaces the
    ~16 ms/layer gather-then-SDPA path in the transformers
    continuous-batching (CB) block-table decode loop on MPS.
  • End-to-end on Qwen1.5-MoE-A2.7B Q4_K_M, M3 Max,
    generate_batch N=8, MAX_NEW=128:
    • varlen (today's mflash): 121.0 tok/s aggregate
    • block-table + paged decode (this PR): 120.6 tok/s with PyTorch
      SDPA. The kernel itself is the unlock; the expectation is to match
      or exceed the varlen path once the .mm binding lands.
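For comparison, here is the gather-then-SDPA baseline in the same style. It is the kind of reference used for the correctness check and the path the paged kernel replaces. Each request's K/V is collected into a contiguous sequence, then ordinary attention runs over it. This is a sketch, not the transformers implementation. It assumes a cache layout of [num_blocks][block_size][heads][head_dim] and H_Q == H_KV. The names `gather_kv`, `sdpa_decode` and `max_abs_diff` are hypothetical.

```python
import math


def gather_kv(cache, block_table_row, seq_len, block_size, head):
    """Collect one request's cached vectors for `head` into a contiguous list
    (the materialisation step the paged kernel avoids)."""
    return [
        cache[block_table_row[t // block_size]][t % block_size][head]
        for t in range(seq_len)
    ]


def sdpa_decode(q, keys, values, scale):
    """Plain softmax(q . K^T * scale) @ V for a single query vector."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / z for d in range(len(q))]


def max_abs_diff(a, b):
    """Largest elementwise |a - b|, the error metric quoted in the validation."""
    return max(abs(x - y) for x, y in zip(a, b))
```

In practice `scale` is usually 1/sqrt(head_dim). Comparing this baseline's output against the paged kernel with `max_abs_diff` yields the kind of figure reported above.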

References: paged_decode_attention_f32 source in
huggingface/transformers#45977 →
gguf-dequant-kernels/gguf_dequant_metal/dequantize.metal.
