
Bundle conflicts log - flash-attn-4-sm120-sncbl

Conflicts and genuine bugs encountered while stacking PRs #2348/2349/2389/2439 on current Dao-AILab/flash-attention main. Each entry: location, what the two sides did, how I resolved it, and whether an individual PR needs a backport fix at merge time.

Bundle validated on SM121a (DGX Spark GB10):

  • TMA forward (dense, bf16, causal): max_diff=0.0078
  • Paged KV (varlen, bf16, causal, ps=16): max_diff=0.0039
  • Dropout (bf16, causal, p=0.1, seed=42): finite, magnitude ratio 1.11
  • Plain forward (bf16/fp16, causal={0,1}): max_diff 0.0005-0.008
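
For reference, the shape of the check behind these numbers - a minimal sketch assuming a flash_attn_func entry point with the usual (batch, seqlen, nheads, headdim) layout, compared against a PyTorch f32 SDPA reference; the actual harness and entry path are not part of this log:

import torch
from flash_attn import flash_attn_func

b, s, h, d = 2, 1024, 8, 128
q, k, v = (torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16) for _ in range(3))

out = flash_attn_func(q, k, v, causal=True)

# f32 reference via SDPA (SDPA expects heads-first layout, hence the transposes)
ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2).float(), k.transpose(1, 2).float(), v.transpose(1, 2).float(),
    is_causal=True,
).transpose(1, 2)

print("max_diff =", (out.float() - ref).abs().max().item())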

Conflicts addressed during stacking (each has a detailed entry below):

  • [SEMANTIC] interface.py SM120 dispatch - #2348 × #2349 × #2389 × #2439
  • [SEMANTIC] flash_fwd.py FlashAttentionForwardSm80.call() - #2348 × #2389
  • [MINOR] flash_fwd.py launch()/kernel() args - #2389 × #2439


[BUG found in #2349] FlashAttentionForwardSm120Tma.call stream position

Location: flash_attn/cute/flash_fwd_sm120_tma.py __call__ signature.

Bug: stream is declared at position 7 (right after softmax_scale) instead of at the end. The base FlashAttentionForwardSm80.__call__ and FlashAttentionForwardSm100.__call__ both put stream last, with a comment: "Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI)". cute.compile binds args positionally, so the compile_args list (which ends with current_stream) would pass cu_seqlens_q_tensor where the TMA kernel expects stream.

When #2349 lands alone: Manifests as a runtime error once the compiled TMA kernel is invoked through _flash_attn_fwd: DSLRuntimeError: expects argument #20 (dropout_seed_hi) to be one of (Int32, NoneType), but got _FakeStream. The TMA PR branch's own tests probably route the stream differently or never hit this code path.

Fix in bundle: moved stream to the end of the TMA __call__ signature to match the base class, and added dropout_seed_lo / dropout_seed_hi parameters that are accepted but asserted to be None (interface dispatch already gates TMA on dropout_p == 0).
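
For illustration, the shape of the corrected __call__ tail (a sketch only; argument names other than stream and the dropout seeds are placeholders, not the exact upstream signature):

# Sketch: stream moves to the very end, after the new dropout seed parameters,
# matching the Sm80/Sm100 "stream last" convention (EnvStream via TVM FFI) so
# cute.compile's positional binding lines up with a compile_args list that ends
# in current_stream.
def __call__(self, mQ, mK, mV, mO, mLSE, softmax_scale,
             cu_seqlens_q=None, cu_seqlens_k=None,        # placeholder middle args
             dropout_seed_lo=None, dropout_seed_hi=None,  # accepted, must stay None
             stream=None):                                # always the last parameter
    assert dropout_seed_lo is None and dropout_seed_hi is None, \
        "interface.py dispatches to the TMA kernel only when dropout_p == 0"
    ...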

PR backport recommendation: apply the same stream-last fix to #2349 directly. It is independent of any other PR; the TMA kernel is broken as written.


[BUG found during merge] DSL SSA collision on shared locals row_scale / sO

Location: flash_attn/cute/flash_fwd.py inside FlashAttentionForwardSm80.call(). #2389 adds a block-sparse branch that defines row_scale and sO. #2348/base dense path also defines row_scale and sO inside a runtime if n_block_max > n_block_min: gate. When both branches coexist, the DSL's SSA analysis sees the same variable name assigned in both a compile-time branch (block-sparse) and a dynamic-if branch (dense), then rejects the dynamic assignment: "row_scale is None prior to this if, and update to _Tensor inside of this if is not supported."
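
Schematically, this is the pattern the DSL rejects (a sketch; surrounding code elided, only the assignment structure matters):

# Both branches assign the same names: one under a compile-time const_expr gate
# (block-sparse), one under a runtime if (dense). The SSA pass then refuses the
# dynamic assignment because row_scale "is None prior to this if".
if const_expr(blocksparse_tensors is not None):
    row_scale = ...        # block-sparse branch defines it
    sO = ...
if n_block_max > n_block_min:
    row_scale = ...        # dense branch defines it again -> rejected
    sO = ...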

Fix in bundle: renamed the block-sparse locals to bs_row_scale / bs_sO. They are used only within that branch, so the rename has no downstream effect.

Also: replaced my initial combined gate if const_expr(blocksparse_tensors is None) and n_block_max > n_block_min: with nested if const_expr(blocksparse_tensors is None): / if n_block_max > n_block_min:, so the DSL evaluates the const_expr purely at compile time.

PR backport recommendation: this only surfaces when #2348 and #2389 are both applied. Whichever PR merges second needs to rename its row_scale/sO to avoid collision and nest the conditions. Either PR can absorb the fix; easier to do in #2389 since it already introduces the block-sparse branch.


[SEMANTIC] interface.py SM120 dispatch

Four-way conflict (same location; all four PRs touch it):

  • #2348 sets num_stages=2 unconditionally and asserts that block-sparsity is off.
  • #2349 adds FlashAttentionForwardSm120Tma dispatch.
  • #2389 enables block-sparse on SM120, which conflicts with #2348's assert.
  • #2439 passes p_dropout=dropout_p to the SM120 constructor; the TMA kernel doesn't implement dropout.

Bundle resolution:

is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None
use_tma_sm120 = (
    page_table is None
    and not is_varlen
    and not use_block_sparsity
    and dropout_p == 0.0
)
if use_tma_sm120 and FlashAttentionForwardSm120Tma.can_implement(...):
    fa_fwd = FlashAttentionForwardSm120Tma(...)
else:
    num_stages_sm120 = 2 if page_table is not None else 1
    fa_fwd = FlashAttentionForwardSm120(
        ..., num_stages=num_stages_sm120, ..., p_dropout=dropout_p,
    )

PR backport at merge: whichever PR lands LAST among the four needs to incorporate the combined dispatch; the delta is roughly 5-10 lines per PR. Flag this in the PR descriptions now.


[SEMANTIC] flash_fwd.py call() - #2348 × #2389

Conflict: #2348 wraps prologue/mainloop in if n_block_max > n_block_min: (split-KV empty-range guard) + adds split_idx + uses n_block - n_block_min bound. #2389 adds a separate block-sparse mainloop path that completes to its own epilogue, then re-gates dense on if const_expr(blocksparse_tensors is None).

Bundle resolution:

if const_expr(blocksparse_tensors is not None):
    # block-sparse complete flow (bs_row_scale, bs_sO - see BUG note above)
    ...
if const_expr(blocksparse_tensors is None):
    if n_block_max > n_block_min:
        # dense flow (split-KV-aware, split_idx in epilogue)
        ...

PR backport: same guidance as dispatch conflict. Scope ~30-40 lines of code movement. Biggest rebase cost; flag in whichever PR is waiting.


[MINOR] flash_fwd.py launch()/kernel() args - #2389 × #2439

Conflict: both PRs add kwargs to the same launch call, kernel signature, and interface.py call_args list. Bundle resolution: keep both groups (blocksparse_tensors first, then dropout_nheads/lo/hi). Trivial to backport.
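
A sketch of the merged call site (illustrative; the launcher name here is hypothetical, the surrounding plumbing is elided, and the exact kwarg spellings on the launch call are inferred from the two PRs):

# Keep both kwarg groups on the same launch call: #2389's block-sparse tensors
# first, then #2439's dropout arguments.
fa_fwd_kernel_launch(
    ...,                                       # existing args, unchanged
    blocksparse_tensors=blocksparse_tensors,   # from #2389
    dropout_nheads=dropout_nheads,             # from #2439
    dropout_seed_lo=dropout_seed_lo,
    dropout_seed_hi=dropout_seed_hi,
)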


Bundle branch: bundle/sncbl-sm120. Final validated commits:

  1. 657e8b5 Apply PR #2348 (+#2336)
  2. fa631d1 Apply PR #2349 + merge with #2348 dispatch (amended: TMA stream-last fix)
  3. d921b6a Apply PR #2389 + merge with #2348/#2349
  4. a865455 Apply PR #2439 + merge with #2389/dispatch (amended: bs_ rename + nested gate)

All syntax-clean and validated on SM121a end-to-end through three distinct kernel paths (TMA, paged KV, dropout).


[BUG found in bundle, fixed via #2484] GQA / MQA crd2idx error - flash_fwd_sm120{,_tma}.py

Symptom (pre-fix): any call to flash_attn_func or flash_attn_varlen_func on SM120 with qhead_per_kvhead > 1 (i.e., real GQA / MQA workloads like Qwen3, LLaMA3) failed at compile time with:

loc("tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint()"
    ("flash_attn/cute/pack_gqa.py":139:20)):
error: unable to compute crd2idx with
       '!cute.layout<"(?):(?{i64 div=8})">' and
       '!cute.coord<"((?,?))">'

Cause: Sm80.__call__'s epilogue (which Sm120 inherits and Sm120Tma calls via self.epilogue) takes the pack_gqa.store_O branch when self.pack_gqa is True (default for qhead_per_kvhead > 1 in interface.py). pack_gqa.store_O calls compute_ptr (pack_gqa.py:139) which expects a packed ((qhead_per_kvhead, seqlen_q), headdim) layout. Sm90 and Sm100 apply pack_gqa_layout before handing tensors to PackGQA, but Sm80 does not, so the layout reaching compute_ptr is un-packed and crd2idx against the hierarchical coord fails. Even adding the pack_gqa_layout calls is not sufficient because Sm80's mainloop tile sizing assumes tile_m divides the seqlen dimension cleanly, which fails when qhead_per_kvhead does not divide tile_m.

Fix: override self.pack_gqa = False in both FlashAttentionForwardSm120.__init__ and FlashAttentionForwardSm120Tma.__init__, after super().__init__(). This routes GQA / MQA through the non-packed epilogue branch, which is functionally correct on every shape tested. Tracked upstream as Dao-AILab/flash-attention#2484.
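
The override is a single post-super() assignment per class; a sketch (base class taken from the cause analysis above, constructor arguments abbreviated):

class FlashAttentionForwardSm120(FlashAttentionForwardSm80):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Sm80's epilogue (inherited here, and called by Sm120Tma via self.epilogue)
        # never applies pack_gqa_layout, so force the non-packed epilogue branch.
        self.pack_gqa = False
# The same post-super() assignment goes into FlashAttentionForwardSm120Tma.__init__.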

Validation post-fix: 64 / 64 configurations pass on SM121a (MHA + GQA Qwen3 + GQA LLaMA3 + MQA, dense + varlen, bf16 + fp16, causal + non-causal, batched). Max diff ≤ 0.0156 against the PyTorch f32 reference.