Bundle conflicts log - flash-attn-4-sm120-sncbl
Conflicts and genuine bugs encountered while stacking PRs #2348/2349/2389/2439 on current Dao-AILab/flash-attention main. Each entry: location, what the two sides did, how I resolved it, and whether an individual PR needs a backport fix at merge time.
Bundle validated on SM121a (DGX Spark GB10):
- TMA forward (dense, bf16, causal): max_diff=0.0078
- Paged KV (varlen, bf16, causal, ps=16): max_diff=0.0039
- Dropout (bf16, causal, p=0.1, seed=42): finite, magnitude ratio 1.11
- Plain forward (bf16/fp16, causal={0,1}): max_diff 0.0005-0.008
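For context, each of these numbers is a max-abs-diff against a PyTorch f32 reference. A minimal sketch of such a check is below; the import path, the tuple handling, and the SDPA reference are assumptions about the harness, not the exact script used.
import torch
import torch.nn.functional as F
from flash_attn.cute.interface import flash_attn_func  # assumed import path

def max_diff_vs_ref(batch=2, seqlen=512, nheads=8, headdim=64,
                    dtype=torch.bfloat16, causal=True):
    q, k, v = (torch.randn(batch, seqlen, nheads, headdim,
                           device="cuda", dtype=dtype) for _ in range(3))
    out = flash_attn_func(q, k, v, causal=causal)
    if isinstance(out, tuple):  # some interface versions also return the LSE
        out = out[0]
    # f32 reference; SDPA expects (batch, nheads, seqlen, headdim)
    ref = F.scaled_dot_product_attention(
        q.transpose(1, 2).float(), k.transpose(1, 2).float(),
        v.transpose(1, 2).float(), is_causal=causal,
    ).transpose(1, 2)
    return (out.float() - ref).abs().max().item()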
[SEMANTIC] interface.py SM120 dispatch - #2348 × #2349 × #2389 × #2439
See entry below; addressed during stacking.
[SEMANTIC] flash_fwd.py FlashAttentionForwardSm80.__call__() - #2348 × #2389
See entry below; addressed during stacking.
[MINOR] flash_fwd.py launch()/kernel() args - #2389 × #2439
See entry below; addressed during stacking.
[BUG found in #2349] FlashAttentionForwardSm120Tma.__call__ stream position
Location: flash_attn/cute/flash_fwd_sm120_tma.py __call__ signature.
Bug: stream is declared at position 7 (right after softmax_scale)
instead of at the end. The base FlashAttentionForwardSm80.__call__
and FlashAttentionForwardSm100.__call__ both put stream last, with
a comment: "Always keep stream as the last parameter (EnvStream:
obtained implicitly via TVM FFI)". cute.compile binds args
positionally, so the compile_args list (which ends with current_stream)
would pass cu_seqlens_q_tensor where the TMA kernel expects stream.
When #2349 lands alone: manifests as a runtime error once the
compiled TMA kernel is invoked through _flash_attn_fwd:
DSLRuntimeError: expects argument #20 (dropout_seed_hi) to be one of (Int32, NoneType), but got _FakeStream.
The TMA PR branch's own tests probably route the stream differently or never hit this code path.
Fix in bundle: Moved stream to the end of the TMA __call__
signature to match the base class. Also added dropout_seed_lo /
dropout_seed_hi parameters that accept and assert they're None
(interface dispatch already gates TMA on dropout_p == 0).
PR backport recommendation: apply the same stream-last fix to #2349 directly. Independent of any other PR - the TMA kernel is broken as-written.
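For illustration, the corrected shape of the signature (middle parameters abridged; everything except softmax_scale, the seed parameters, and stream is a placeholder name):
def __call__(
    self,
    mQ, mK, mV, mO, mLSE,      # placeholder tensor parameters
    softmax_scale,
    # ... unchanged middle parameters (cu_seqlens, page_table, ...) ...
    dropout_seed_lo=None,      # accepted so the shared compile_args list lines up
    dropout_seed_hi=None,      # dispatch already guarantees dropout_p == 0 for TMA
    stream=None,               # always keep stream as the last parameter (EnvStream via TVM FFI)
):
    assert dropout_seed_lo is None and dropout_seed_hi is None
    ...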
[BUG found during merge] DSL SSA collision on shared locals row_scale / sO
Location: flash_attn/cute/flash_fwd.py inside
FlashAttentionForwardSm80.__call__(). #2389 adds a block-sparse branch
that defines row_scale and sO. #2348/base dense path also defines
row_scale and sO inside a runtime if n_block_max > n_block_min:
gate. When both branches coexist, the DSL's SSA analysis sees the same
variable name assigned in both a compile-time branch (block-sparse)
and a dynamic-if branch (dense), then rejects the dynamic assignment:
"row_scale is None prior to this if, and update to _Tensor inside of this if is not supported."
Fix in bundle: renamed block-sparse locals to bs_row_scale /
bs_sO. They're local-use-only, so the rename has no downstream
effect.
Also: replaced my initial combined gate
    if const_expr(blocksparse_tensors is None) and n_block_max > n_block_min:
with nested gates
    if const_expr(blocksparse_tensors is None):
        if n_block_max > n_block_min:
so the DSL evaluates the const_expr purely at compile time.
PR backport recommendation: this only surfaces when #2348 and #2389 are both applied. Whichever PR merges second needs to rename its row_scale/sO to avoid collision and nest the conditions. Either PR can absorb the fix; easier to do in #2389 since it already introduces the block-sparse branch.
[SEMANTIC] interface.py SM120 dispatch
Four-way conflict (same location, all four PRs touch it):
- #2348 sets num_stages=2 unconditionally and gates on no block-sparse.
- #2349 adds FlashAttentionForwardSm120Tma dispatch.
- #2389 enables block-sparse on SM120 - conflicts with #2348's assert.
- #2439 passes p_dropout=dropout_p to the SM120 constructor; the TMA kernel doesn't implement dropout.
Bundle resolution:
is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None
use_tma_sm120 = (
page_table is None
and not is_varlen
and not use_block_sparsity
and dropout_p == 0.0
)
if use_tma_sm120 and FlashAttentionForwardSm120Tma.can_implement(...):
fa_fwd = FlashAttentionForwardSm120Tma(...)
else:
num_stages_sm120 = 2 if page_table is not None else 1
fa_fwd = FlashAttentionForwardSm120(
..., num_stages=num_stages_sm120, ..., p_dropout=dropout_p,
)
PR backport at merge: whichever PR lands LAST among the four needs to incorporate the combined dispatch. ~5-10 lines each. Flag in PR descriptions now.
[SEMANTIC] flash_fwd.py __call__() - #2348 × #2389
Conflict: #2348 wraps prologue/mainloop in
if n_block_max > n_block_min: (split-KV empty-range guard) + adds
split_idx + uses n_block - n_block_min bound. #2389 adds a
separate block-sparse mainloop path that completes to its own
epilogue, then re-gates dense on if const_expr(blocksparse_tensors is None).
Bundle resolution:
if const_expr(blocksparse_tensors is not None):
    # block-sparse complete flow (bs_row_scale, bs_sO - see BUG note above)
    ...
if const_expr(blocksparse_tensors is None):
    if n_block_max > n_block_min:
        # dense flow (split-KV-aware, split_idx in epilogue)
        ...
PR backport: same guidance as dispatch conflict. Scope ~30-40 lines of code movement. Biggest rebase cost; flag in whichever PR is waiting.
[MINOR] flash_fwd.py launch()/kernel() args - #2389 × #2439
Conflict: both add kwargs to the same launch call + kernel signature + interface.py call_args list. Bundle resolution: keep both (blocksparse_tensors, then dropout_nheads/lo/hi). Trivial to backport.
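Sketch of the merged kernel signature (existing parameters and decorators elided; the added names are those described above, and the same ordering is mirrored in the launch call and the interface.py call_args list):
def kernel(
    self,
    # ... existing parameters ...
    blocksparse_tensors,                                # added by #2389
    dropout_nheads, dropout_seed_lo, dropout_seed_hi,   # added by #2439
):
    ...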
Bundle branch: bundle/sncbl-sm120. Final validated commits:
- 657e8b5  Apply PR #2348 (+#2336)
- fa631d1  Apply PR #2349 + merge with #2348 dispatch (amended: TMA stream-last fix)
- d921b6a  Apply PR #2389 + merge with #2348/#2349
- a865455  Apply PR #2439 + merge with #2389/dispatch (amended: bs_ rename + nested gate)
All syntax-clean and validated on SM121a end-to-end through three distinct kernel paths (TMA, paged KV, dropout).
[BUG found in bundle, fixed via #2484] GQA / MQA crd2idx error - flash_fwd_sm120{,_tma}.py
Symptom (pre-fix): any call to flash_attn_func or
flash_attn_varlen_func on SM120 with qhead_per_kvhead > 1 (i.e.,
real GQA / MQA workloads like Qwen3, LLaMA3) failed at compile time
with:
loc("tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint()"
("flash_attn/cute/pack_gqa.py":139:20)):
error: unable to compute crd2idx with
'!cute.layout<"(?):(?{i64 div=8})">' and
'!cute.coord<"((?,?))">'
Cause: Sm80.__call__'s epilogue (which Sm120 inherits and
Sm120Tma calls via self.epilogue) takes the pack_gqa.store_O
branch when self.pack_gqa is True (default for qhead_per_kvhead > 1
in interface.py). pack_gqa.store_O calls compute_ptr (pack_gqa.py:139)
which expects a packed ((qhead_per_kvhead, seqlen_q), headdim) layout.
Sm90 and Sm100 apply pack_gqa_layout before handing tensors to
PackGQA, but Sm80 does not, so the layout reaching compute_ptr is
un-packed and crd2idx against the hierarchical coord fails. Even adding
the pack_gqa_layout calls is not sufficient because Sm80's mainloop
tile sizing assumes tile_m divides the seqlen dimension cleanly,
which fails when qhead_per_kvhead does not divide tile_m.
Fix: override self.pack_gqa = False in both
FlashAttentionForwardSm120.__init__ and
FlashAttentionForwardSm120Tma.__init__, after super().__init__().
This routes GQA / MQA through the non-packed epilogue branch which is
functionally correct on every shape tested. Tracked upstream as
Dao-AILab/flash-attention#2484.
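Condensed sketch of the override (base class per the inheritance noted in the Cause paragraph; constructor parameters collapsed to *args/**kwargs, the real __init__ signatures are unchanged):
class FlashAttentionForwardSm120(FlashAttentionForwardSm80):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Sm80's epilogue never applies pack_gqa_layout, so force the
        # non-packed store path for GQA / MQA (tracked as #2484)
        self.pack_gqa = False

# the same override goes at the end of FlashAttentionForwardSm120Tma.__init__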
Validation post-fix: 64 / 64 configurations pass on SM121a (MHA + GQA Qwen3 + GQA LLaMA3 + MQA, dense + varlen, bf16 + fp16, causal + non-causal, batched). Max diff ≤ 0.0156 against the PyTorch f32 reference.