Commit 1f83cde by drbh (initial commit, 0 parents)

feat: build flash mla with kernel builder
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .bak
+ __pycache__
README.md ADDED
@@ -0,0 +1,22 @@
+ ---
+ tags:
+ - kernel
+ - flash-mla
+ - deepseek
+ - kernel-builder
+ ---
+
+ # flash-mla
+
+ This repo builds DeepSeek's [FlashMLA](https://github.com/deepseek-ai/FlashMLA) kernel via the HF [kernel-builder](https://github.com/huggingface/kernel-builder).
+
+ ### Dev
+ ```bash
+ nix develop -L
+ pytest -vv tests/
+ ```
+
+ ### Build
+ ```bash
+ nix build .#bundle -L
+ ```
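The sources below compile into a Torch extension whose host-side entry points are defined in `flash_mla/flash_mla_api.cu` further down in this commit; the Python-facing registration lives in `torch-ext/torch_binding.cpp`, which `build.toml` references but this commit view does not show. As a condensed, non-authoritative summary, the two C++ entry points are:

```cpp
#include <vector>
#include <torch/all.h>

// Plans the split-KV schedule from per-batch KV lengths
// (returns {tile_scheduler_metadata, num_splits}).
std::vector<at::Tensor> get_mla_metadata(
    at::Tensor &seqlens_k,
    const int64_t num_heads_per_head_k,
    const int64_t num_heads_k);

// Runs the MLA decode kernel against a paged KV cache
// (returns {out, softmax_lse}).
std::vector<at::Tensor> mha_fwd_kvcache_mla(
    at::Tensor &q,                              // batch x seqlen_q x heads x head_size
    const at::Tensor &kcache,                   // num_blocks x page_block_size x heads_k x head_size
    const at::Tensor &vcache_,                  // shares storage with kcache
    const int64_t head_size_v,
    const at::Tensor &seqlens_k,                // batch
    const at::Tensor &block_table,              // batch x max_num_blocks_per_seq
    const double softmax_scale,
    const bool is_causal_,
    const at::Tensor &tile_scheduler_metadata,  // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits,               // batch + 1
    const int64_t unknown_param);
```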
build.toml ADDED
@@ -0,0 +1,27 @@
+ [general]
+ name = "flash_mla"
+
+ [torch]
+ src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
+
+
+ [kernel.activation]
+ cuda-capabilities = [
+     # "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9",
+
+     # Only available on H100 and H200
+     "9.0", # (Hopper)
+ ]
+ src = [
+     "flash_mla/flash_mla_api.cu",
+     "flash_mla/flash_fwd_mla_bf16_sm90.cu",
+     "flash_mla/flash_fwd_mla_fp16_sm90.cu",
+     "flash_mla/flash_fwd_mla_kernel.h",
+     "flash_mla/flash_fwd_mla_metadata.cu",
+     "flash_mla/flash_mla.h",
+     "flash_mla/named_barrier.h",
+     "flash_mla/softmax.h",
+     "flash_mla/static_switch.h",
+     "flash_mla/utils.h",
+ ]
+ depends = ["torch", "cutlass_3_6"]
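`build.toml` limits compilation to CUDA capability 9.0, and `flash_mla_api.cu` additionally rejects non-Hopper devices at runtime via `at::cuda::getCurrentDeviceProperties`. A minimal stand-alone sketch of that kind of guard using the plain CUDA runtime API (names here are illustrative, not part of this repo):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Check that the current device is SM90 (Hopper), mirroring the
// is_sm90 check done in flash_mla_api.cu before dispatch.
bool current_device_is_sm90() {
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, device);
    return prop.major == 9 && prop.minor == 0;
}

int main() {
    std::printf("SM90 available: %s\n", current_device_is_sm90() ? "yes" : "no");
    return 0;
}
```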
flake.lock ADDED
@@ -0,0 +1,117 @@
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1733328505,
6
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-utils": {
19
+ "inputs": {
20
+ "systems": "systems"
21
+ },
22
+ "locked": {
23
+ "lastModified": 1731533236,
24
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
25
+ "owner": "numtide",
26
+ "repo": "flake-utils",
27
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
28
+ "type": "github"
29
+ },
30
+ "original": {
31
+ "owner": "numtide",
32
+ "repo": "flake-utils",
33
+ "type": "github"
34
+ }
35
+ },
36
+ "kernel-builder": {
37
+ "inputs": {
38
+ "flake-compat": "flake-compat",
39
+ "flake-utils": "flake-utils",
40
+ "nixpkgs": "nixpkgs",
41
+ "rocm-nix": "rocm-nix"
42
+ },
43
+ "locked": {
44
+ "lastModified": 1740571741,
45
+ "narHash": "sha256-MIy7OBrhz8OqFaLT3/MpvL+IGz720le2Ru2wGYPdswo=",
46
+ "ref": "refs/heads/main",
47
+ "rev": "5062dad8a4818d7239f59e1153b8c18b8b59da74",
48
+ "revCount": 89,
49
+ "type": "git",
50
+ "url": "ssh://git@github.com/huggingface/kernel-builder"
51
+ },
52
+ "original": {
53
+ "type": "git",
54
+ "url": "ssh://git@github.com/huggingface/kernel-builder"
55
+ }
56
+ },
57
+ "nixpkgs": {
58
+ "locked": {
59
+ "lastModified": 1740344854,
60
+ "narHash": "sha256-+TiHtSOo+RPUNrcfkcGXmapJ40O3gt6yOe2nA8y0KPw=",
61
+ "owner": "nixos",
62
+ "repo": "nixpkgs",
63
+ "rev": "0b3aa63c013cf9302afc0ba5dbd81f8fab7bd94f",
64
+ "type": "github"
65
+ },
66
+ "original": {
67
+ "owner": "nixos",
68
+ "ref": "nixos-unstable-small",
69
+ "repo": "nixpkgs",
70
+ "type": "github"
71
+ }
72
+ },
73
+ "rocm-nix": {
74
+ "inputs": {
75
+ "nixpkgs": [
76
+ "kernel-builder",
77
+ "nixpkgs"
78
+ ]
79
+ },
80
+ "locked": {
81
+ "lastModified": 1740473629,
82
+ "narHash": "sha256-xW5RfZScKmFymmwdBSZvNZxvFzQSJu8lgF0cQKomT2E=",
83
+ "owner": "huggingface",
84
+ "repo": "rocm-nix",
85
+ "rev": "e4b51d092caf52c693c330c177369adbf6a153ba",
86
+ "type": "github"
87
+ },
88
+ "original": {
89
+ "owner": "huggingface",
90
+ "repo": "rocm-nix",
91
+ "type": "github"
92
+ }
93
+ },
94
+ "root": {
95
+ "inputs": {
96
+ "kernel-builder": "kernel-builder"
97
+ }
98
+ },
99
+ "systems": {
100
+ "locked": {
101
+ "lastModified": 1681028828,
102
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
103
+ "owner": "nix-systems",
104
+ "repo": "default",
105
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "nix-systems",
110
+ "repo": "default",
111
+ "type": "github"
112
+ }
113
+ }
114
+ },
115
+ "root": "root",
116
+ "version": 7
117
+ }
flake.nix ADDED
@@ -0,0 +1,14 @@
+ {
+   description = "Flake for FlashMLA kernel";
+
+   inputs = {
+     kernel-builder.url = "git+ssh://git@github.com/huggingface/kernel-builder";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs ./.;
+ }
flash_mla/flash_fwd_mla_bf16_sm90.cu ADDED
@@ -0,0 +1,3 @@
+ #include "flash_fwd_mla_kernel.h"
+
+ template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
flash_mla/flash_fwd_mla_fp16_sm90.cu ADDED
@@ -0,0 +1,3 @@
+ #include "flash_fwd_mla_kernel.h"
+
+ template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
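Each of the two `.cu` files above is a single explicit instantiation of `run_mha_fwd_splitkv_mla` (declared in `flash_mla/flash_mla.h`), so the bf16 and fp16 variants build as separate compilation jobs while the heavy `flash_fwd_mla_kernel.h` template lives in one place. A generic sketch of the pattern with hypothetical names:

```cpp
// Hypothetical illustration of the split used by the two .cu files above:
// the template is defined once (normally in a shared header), and each
// translation unit forces exactly one instantiation.

// --- what would live in the shared header ---
template <typename T>
void scale_buffer(T *data, int n, T factor) {
    for (int i = 0; i < n; ++i) data[i] *= factor;
}

// --- what would live in scale_buffer_float.cpp ---
template void scale_buffer<float>(float *data, int n, float factor);

// --- what would live in scale_buffer_double.cpp ---
template void scale_buffer<double>(double *data, int n, double factor);
```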
flash_mla/flash_fwd_mla_kernel.h ADDED
@@ -0,0 +1,603 @@
1
+ #pragma once
2
+
3
+ #include <cute/tensor.hpp>
4
+ #include <cutlass/cutlass.h>
5
+ #include <cutlass/array.h>
6
+ #include <cutlass/numeric_types.h>
7
+
8
+ using namespace cute;
9
+
10
+ #include "named_barrier.h"
11
+ #include "utils.h"
12
+ #include "softmax.h"
13
+ #include "static_switch.h"
14
+ #include "flash_mla.h"
15
+
16
+
17
+ template<typename PrecType, int DIM, int DIM2 = DIM>
18
+ constexpr auto getSmemLayoutK() {
19
+ constexpr int headSizeBytes = sizeof(PrecType) * DIM;
20
+ constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
21
+
22
+ if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
23
+ return GMMA::Layout_K_SW128_Atom<PrecType>{};
24
+ } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
25
+ return GMMA::Layout_K_SW64_Atom<PrecType>{};
26
+ } else {
27
+ return GMMA::Layout_K_SW32_Atom<PrecType>{};
28
+ }
29
+ }
30
+
31
+ template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
32
+ struct Flash_fwd_kernel_traits_mla {
33
+ using Element = elem_type;
34
+ using ElementAccum = float;
35
+ using index_t = int64_t;
36
+
37
+ static constexpr int kNWarps = kNWarps_;
38
+ static constexpr int kNThreads = kNWarps * 32;
39
+ static constexpr int kNWarpsS = 4;
40
+ static constexpr int kNThreadsS = kNWarpsS * 32;
41
+
42
+ static constexpr int kBlockM = kBlockM_;
43
+ static constexpr int kBlockN = kBlockN_;
44
+ static constexpr int kHeadDim = kHeadDim_;
45
+ static_assert(kHeadDim % 32 == 0);
46
+ static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
47
+ static_assert(kHeadDimV % 32 == 0);
48
+ static_assert(kHeadDimV <= kHeadDim);
49
+ static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
50
+ static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
51
+
52
+ using TiledMma = decltype(make_tiled_mma(
53
+ cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
54
+ GMMA::Major::K, GMMA::Major::K>(),
55
+ Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
56
+
57
+ static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
58
+ using TiledMmaO = decltype(make_tiled_mma(
59
+ cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
60
+ GMMA::Major::K, GMMA::Major::MN>(),
61
+ Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
62
+
63
+ using SmemLayoutQ = decltype(tile_to_shape(
64
+ getSmemLayoutK<Element, kHeadDim>(),
65
+ Shape<Int<kBlockM>, Int<kHeadDim>>{}));
66
+
67
+ using SmemLayoutK = decltype(tile_to_shape(
68
+ getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
69
+ Shape<Int<kBlockN>, Int<kHeadDim>>{}));
70
+
71
+ using SmemLayoutV = decltype(tile_to_shape(
72
+ getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
73
+ Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
74
+ using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
75
+
76
+ using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
77
+ using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
78
+
79
+ using SmemLayoutAtomO = decltype(composition(
80
+ Swizzle<kSwizzle, 3, 3>{},
81
+ Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
82
+ using SmemLayoutO = decltype(tile_to_shape(
83
+ SmemLayoutAtomO{},
84
+ Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
85
+ using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
86
+ using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
87
+
88
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
89
+ static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
90
+ static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
91
+ using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
92
+ static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
93
+ static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
94
+
95
+ using GmemLayoutAtom = Layout<
96
+ Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
97
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
98
+ using GmemTiledCopy = decltype(make_tiled_copy(
99
+ Copy_Atom<Gmem_copy_struct, Element>{},
100
+ GmemLayoutAtom{},
101
+ Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
102
+
103
+ using GmemLayoutAtomO = Layout<
104
+ Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
105
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
106
+ using GmemTiledCopyO = decltype(make_tiled_copy(
107
+ Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
108
+ GmemLayoutAtomO{},
109
+ Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
110
+
111
+ static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
112
+ static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
113
+ using GmemLayoutAtomOaccum = Layout<
114
+ Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
115
+ Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
116
+ using GmemTiledCopyOaccum = decltype(make_tiled_copy(
117
+ Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
118
+ GmemLayoutAtomOaccum{},
119
+ Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
120
+ };
121
+
122
+ namespace flash {
123
+
124
+ using namespace cute;
125
+
126
+ template<typename Kernel_traits>
127
+ struct SharedStorageMLA {
128
+ union {
129
+ struct {
130
+ cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
131
+ cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
132
+ cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
133
+ cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
134
+ };
135
+ struct {
136
+ cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
137
+ cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
138
+ cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
139
+ };
140
+ };
141
+ };
142
+
143
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
144
+
145
+ template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
146
+ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
147
+ SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
148
+ constexpr int kBlockM = Kernel_traits::kBlockM;
149
+ constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
150
+ constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
151
+ using Element = typename Kernel_traits::Element;
152
+ using ElementAccum = typename Kernel_traits::ElementAccum;
153
+ using index_t = typename Kernel_traits::index_t;
154
+
155
+ const int tidx = threadIdx.x;
156
+
157
+ typename Kernel_traits::TiledMmaO tiled_mma_o;
158
+ auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
159
+
160
+ // Epilogue
161
+
162
+ const int split_offset = __ldg(params.num_splits_ptr + bidb);
163
+
164
+ Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
165
+
166
+ using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
167
+ Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
168
+ // Partition sO to match the accumulator partitioning
169
+ using SmemTiledCopyO = std::conditional_t<
170
+ !Split,
171
+ typename Kernel_traits::SmemCopyAtomO,
172
+ typename Kernel_traits::SmemCopyAtomOaccum
173
+ >;
174
+ auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
175
+ auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
176
+ Tensor rO = flash::convert_type<ElementO>(tOrO);
177
+ Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
178
+ Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
179
+
180
+ __syncthreads();
181
+
182
+ cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
183
+
184
+ const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
185
+ const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
186
+ const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
187
+ const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
188
+
189
+ Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
190
+ Shape<Int<kBlockM>, Int<kHeadDimV>>{},
191
+ make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
192
+ Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
193
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
194
+
195
+ using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
196
+ GmemTiledCopyO gmem_tiled_copy_Oaccum;
197
+ auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
198
+ Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
199
+ Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
200
+
201
+ __syncthreads();
202
+
203
+ if (tidx >= kNThreadsS) { return; }
204
+
205
+ Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
206
+ cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
207
+
208
+ Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
209
+ Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
210
+ Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
211
+ CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
212
+ if (get<1>(taccOcO_row(0)) == 0) {
213
+ #pragma unroll
214
+ for (int mi = 0; mi < size(lse); ++mi) {
215
+ const int row = get<0>(taccOcO_row(mi));
216
+ if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
217
+ }
218
+ }
219
+
220
+ // Construct identity layout for sO
221
+ Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
222
+ // Repeat the partitioning with identity layouts
223
+ Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
224
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
225
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
226
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
227
+ gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
228
+ );
229
+ }
230
+
231
+ template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
232
+ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
233
+ const int bidb, const int bidh, const int m_block,
234
+ const int n_split_idx, const int seqlen_k,
235
+ const int n_block_min, const int n_block_max, const bool NoSplit,
236
+ SharedStorage &shared_storage) {
237
+ constexpr int kBlockM = Kernel_traits::kBlockM;
238
+ constexpr int kBlockN = Kernel_traits::kBlockN;
239
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
240
+ constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
241
+ constexpr int kNThreads = Kernel_traits::kNThreads;
242
+ constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
243
+ static_assert(kNThreads == 256 and kNThreadsS == 128);
244
+ using Element = typename Kernel_traits::Element;
245
+ using index_t = typename Kernel_traits::index_t;
246
+
247
+ const int tidx = threadIdx.x;
248
+ int n_block = n_block_max - 1;
249
+
250
+ Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
251
+ Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
252
+ Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
253
+ Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
254
+
255
+ Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
256
+ Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
257
+ Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
258
+ Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
259
+ Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
260
+ Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
261
+ Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
262
+ Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
263
+
264
+ typename Kernel_traits::TiledMmaO tiled_mma_o;
265
+ auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
266
+ Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
267
+ Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
268
+ clear(tOrO);
269
+
270
+ flash::Softmax<2 * size<1>(tOrO)> softmax;
271
+
272
+ int warp_group_idx = cutlass::canonical_warp_group_idx();
273
+ if (warp_group_idx == 0) {
274
+ typename Kernel_traits::TiledMma tiled_mma;
275
+ auto thr_mma = tiled_mma.get_thread_slice(tidx);
276
+ Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
277
+ Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
278
+
279
+ if (n_block % 2 == 1) {
280
+ // Double buffer for sK
281
+ constexpr int sK_offset = size(sK);
282
+ tSrK.data() = tSrK.data() + sK_offset / 8;
283
+ tOrVt.data() = tOrVt.data() + sK_offset / 8;
284
+ }
285
+
286
+ // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
287
+ // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
288
+ // We will have at least 1 "masking" iteration.
289
+ // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
290
+ // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
291
+ constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
292
+ #pragma unroll 1
293
+ for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
294
+ __syncthreads();
295
+
296
+ Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
297
+ flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
298
+
299
+ const bool is_masking_step = masking_step > 0;
300
+ const bool is_first_masking_step = masking_step == n_masking_steps;
301
+
302
+ if (is_masking_step) {
303
+ Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
304
+ Tensor tScS = thr_mma.partition_C(cS);
305
+ #pragma unroll
306
+ for (int i = 0; i < size(tSrS); ++i) {
307
+ if constexpr (!Is_causal) { // Just masking based on col
308
+ if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
309
+ } else {
310
+ // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
311
+ // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
312
+ int row = int(get<0>(tScS(i)));
313
+ int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
314
+ if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
315
+ }
316
+ }
317
+ }
318
+
319
+ // We have key_padding_mask so we'll need to Check_inf
320
+ Tensor scale_o = is_first_masking_step
321
+ ? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
322
+ : is_masking_step ?
323
+ softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
324
+ : softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
325
+
326
+ Tensor rP = flash::convert_type<Element>(tSrS);
327
+ cute::copy(rP, tPsP);
328
+ cute::copy(scale_o, tScale_osScale_o);
329
+
330
+ cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
331
+
332
+ flash::rescale_o(tOrO, scale_o);
333
+
334
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
335
+ flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
336
+
337
+ // Double buffer for sK
338
+ const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
339
+ tSrK.data() = tSrK.data() + sK_offset / 8;
340
+ tOrVt.data() = tOrVt.data() + sK_offset / 8;
341
+ }
342
+
343
+ cute::copy(softmax.row_max, tRow_maxsRow_max);
344
+ cute::copy(softmax.row_sum, tRow_sumsRow_sum);
345
+ cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
346
+ } else {
347
+ const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
348
+ int cur_block_table = __ldg(&block_table[n_block]);
349
+
350
+ const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
351
+ Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
352
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
353
+ make_stride(params.q_row_stride, _1{}));
354
+ typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
355
+ auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
356
+ Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
357
+ Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
358
+ Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
359
+ Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
360
+ Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
361
+
362
+ // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
363
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
364
+ params.seqlen_q - m_block * kBlockM);
365
+
366
+ const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
367
+ Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
368
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
369
+ make_stride(params.k_row_stride, _1{}));
370
+ typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
371
+ auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
372
+ Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
373
+ Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
374
+ Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
375
+ Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
376
+ Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
377
+
378
+ if (n_block % 2 == 1) {
379
+ // Double buffer for sK
380
+ constexpr int sK_offset = size(sK);
381
+ tKsK.data() = tKsK.data() + sK_offset;
382
+ tOrVt.data() = tOrVt.data() + sK_offset / 8;
383
+ }
384
+
385
+ // We need to clear the sK smem tiles because K is V.
386
+ const index_t offset_k = cur_block_table * params.k_batch_stride;
387
+ tKgK.data() = tKgK.data() + offset_k;
388
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
389
+ seqlen_k - n_block * kBlockN);
390
+ tKgK.data() = tKgK.data() + -offset_k;
391
+ cute::cp_async_fence();
392
+
393
+ if (n_block - 1 >= n_block_min) {
394
+ cur_block_table = __ldg(&block_table[n_block - 1]);
395
+ }
396
+
397
+ #pragma unroll 1
398
+ for (; n_block >= n_block_min; --n_block) {
399
+ flash::cp_async_wait<0>();
400
+ __syncthreads();
401
+
402
+ if (n_block - 1 >= n_block_min) {
403
+ // Double buffer for sK
404
+ const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
405
+ tKsK.data() = tKsK.data() + sK_offset;
406
+
407
+ const index_t offset_k = cur_block_table * params.k_batch_stride;
408
+ tKgK.data() = tKgK.data() + offset_k;
409
+ flash::copy</*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
410
+ tKgK.data() = tKgK.data() + -offset_k;
411
+ cute::cp_async_fence();
412
+ }
413
+
414
+ cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
415
+
416
+ if (n_block - 2 >= n_block_min) {
417
+ cur_block_table = __ldg(&block_table[n_block - 2]);
418
+ }
419
+
420
+ typename Kernel_traits::TiledMma tiled_mma;
421
+ auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
422
+ Tensor rP = make_tensor<Element>(tSrS_layout);
423
+ Tensor scale_o = make_tensor<float>(Shape<_2>{});
424
+ cute::copy(tScale_osScale_o, scale_o);
425
+ cute::copy(tPsP, rP);
426
+
427
+ flash::rescale_o(tOrO, scale_o);
428
+
429
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
430
+ flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
431
+
432
+ // Double buffer for sK
433
+ const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
434
+ tOrVt.data() = tOrVt.data() + sK_offset / 8;
435
+ }
436
+
437
+ cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
438
+ cute::copy(tRow_maxsRow_max, softmax.row_max);
439
+ cute::copy(tRow_sumsRow_sum, softmax.row_sum);
440
+ }
441
+
442
+ if (NoSplit)
443
+ store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
444
+ else
445
+ store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
446
+ }
447
+
448
+ template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
449
+ __global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
450
+ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
451
+ constexpr int kBlockN = Kernel_traits::kBlockN;
452
+ const int m_block = blockIdx.x;
453
+ const int bidh = blockIdx.y;
454
+ const int partition_idx = blockIdx.z;
455
+
456
+ extern __shared__ char shared_memory[];
457
+ auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
458
+
459
+ int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
460
+ int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
461
+ int begin_idx = tile_scheduler_metadata.x;
462
+ int begin_seqlen = tile_scheduler_metadata.y;
463
+ int end_idx = tile_scheduler_metadata.z;
464
+ int end_seqlen = tile_scheduler_metadata.w;
465
+ if (begin_idx >= params.b) return;
466
+ int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
467
+
468
+ #pragma unroll 1
469
+ for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
470
+ const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
471
+ const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
472
+ const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
473
+ const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
474
+ const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
475
+ if (batch_id > begin_idx) {
476
+ __syncthreads(); // Barrier between two tiles.
477
+ }
478
+ flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
479
+ }
480
+ }
481
+
482
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
483
+
484
+ template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
485
+ __global__ void __launch_bounds__(256, 1, 1)
486
+ flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
487
+ constexpr int kNThreads = 128;
488
+
489
+ const int tidx = threadIdx.x;
490
+ const int bidx = blockIdx.x;
491
+ const int hs = params.h * params.seqlen_q;
492
+ const int batch_idx = bidx / hs;
493
+ const int hs_idx = bidx % hs;
494
+
495
+ const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
496
+ const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
497
+ FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
498
+ if (actual_num_splits == 1) return;
499
+
500
+ __shared__ ElementAccum sLseScale[kMaxSplits];
501
+
502
+ const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
503
+ const index_t row_offset_lse = bidx;
504
+ Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
505
+ Shape<Int<kMaxSplits>>{}, make_stride(hs));
506
+ Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
507
+ Shape<_1>{}, Stride<_1>{});
508
+
509
+ int warp_idx = cutlass::canonical_warp_idx_sync();
510
+ if (warp_idx == 0) {
511
+ constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
512
+
513
+ float local_lse[kNLsePerThread];
514
+ for (int i = 0; i < kNLsePerThread; ++i) {
515
+ const int split = i * 32 + tidx;
516
+ local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
517
+ }
518
+
519
+ float max_lse = -INFINITY;
520
+ for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
521
+ for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
522
+ max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
523
+
524
+ float sum_lse = 0;
525
+ for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
526
+ for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
527
+
528
+ float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
529
+ if (tidx == 0) gLSE(0) = global_lse;
530
+
531
+ for (int i = 0; i < kNLsePerThread; ++i) {
532
+ const int split = i * 32 + tidx;
533
+ if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
534
+ }
535
+ }
536
+ __syncthreads();
537
+
538
+ static_assert(kHeadDimV % kNThreads == 0);
539
+ constexpr int Elements = kHeadDimV / kNThreads;
540
+ const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
541
+ Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
542
+ Shape<Int<kHeadDimV>>{}, Stride<_1>{});
543
+ using GmemTiledCopyOaccum = decltype(make_tiled_copy(
544
+ Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
545
+ Layout<Shape<Int<kNThreads>>>{},
546
+ Layout<Shape<Int<Elements>>>{}));
547
+ GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
548
+ auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
549
+ Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
550
+ Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
551
+ Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
552
+ clear(tOrO);
553
+
554
+ for (int split = 0; split < actual_num_splits; ++split) {
555
+ cute::copy(tOgOaccum, tOrOaccum);
556
+ ElementAccum lse_scale = sLseScale[split];
557
+ for (int i = 0; i < size(tOrO); ++i) {
558
+ tOrO(i) += lse_scale * tOrOaccum(i);
559
+ }
560
+ tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
561
+ }
562
+
563
+ Tensor rO = flash::convert_type<Element>(tOrO);
564
+ const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
565
+ const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
566
+ auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
567
+ Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
568
+ cute::copy(rO, gO);
569
+ }
570
+
571
+ } // namespace flash
572
+
573
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
574
+
575
+ template<typename Kernel_traits, typename SharedStorage>
576
+ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
577
+ FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
578
+ const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
579
+ BOOL_SWITCH(params.is_causal, Is_causal, [&] {
580
+ auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
581
+ constexpr size_t smem_size = sizeof(SharedStorage);
582
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
583
+ kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
584
+ });
585
+ CHECK_CUDA_KERNEL_LAUNCH();
586
+
587
+ dim3 grid_combine(params.b * params.h * params.seqlen_q);
588
+ MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
589
+ auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
590
+ typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
591
+ combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
592
+ });
593
+ CHECK_CUDA_KERNEL_LAUNCH();
594
+ }
595
+
596
+ template<typename T, int Headdim>
597
+ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
598
+ static_assert(Headdim == 576);
599
+ FLASH_ASSERT(params.d_v == 512);
600
+ FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
601
+ using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
602
+ run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
603
+ }
flash_mla/flash_fwd_mla_metadata.cu ADDED
@@ -0,0 +1,77 @@
1
+ #include "flash_fwd_mla_kernel.h"
2
+
3
+ static constexpr int MaxBatchSize = 4096;
4
+
5
+ __global__ void __launch_bounds__(256, 1, 1)
6
+ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
7
+ int *seqlens_k_ptr = params.seqlens_k_ptr;
8
+ int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
9
+ int *num_splits_ptr = params.num_splits_ptr;
10
+ int batch_size = params.batch_size;
11
+ int block_size_n = params.block_size_n;
12
+ int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
13
+ int num_sm_parts = params.num_sm_parts;
14
+
15
+ __shared__ int num_blocks_shared[MaxBatchSize];
16
+ __shared__ int num_splits_shared[MaxBatchSize];
17
+
18
+ int total_num_blocks = 0;
19
+ for (int i = threadIdx.x; i < batch_size; i += 32) {
20
+ int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
21
+ total_num_blocks += num_blocks + fixed_overhead_num_blocks;
22
+ num_blocks_shared[i] = num_blocks;
23
+ }
24
+ for (int offset = 16; offset >= 1; offset /= 2) {
25
+ total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
26
+ }
27
+ __syncwarp();
28
+
29
+ if (threadIdx.x == 0) {
30
+ int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
31
+
32
+ int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
33
+ num_splits_shared[0] = 0;
34
+ for (int i = 0; i < num_sm_parts; ++i) {
35
+ int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
36
+ tile_scheduler_metadata0[0] = now_idx;
37
+ tile_scheduler_metadata0[1] = now_block * block_size_n;
38
+ tile_scheduler_metadata1 = now_n_split_idx;
39
+ int remain_payload = payload;
40
+ while (now_idx < batch_size) {
41
+ int num_blocks = num_blocks_shared[now_idx];
42
+ int now_remain_blocks = num_blocks - now_block;
43
+ if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
44
+ cum_num_splits += now_n_split_idx + 1;
45
+ num_splits_shared[now_idx + 1] = cum_num_splits;
46
+ remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
47
+ ++now_idx;
48
+ now_block = 0;
49
+ now_n_split_idx = 0;
50
+ } else {
51
+ if (remain_payload - fixed_overhead_num_blocks > 0) {
52
+ now_block += remain_payload - fixed_overhead_num_blocks;
53
+ ++now_n_split_idx;
54
+ remain_payload = 0;
55
+ }
56
+ break;
57
+ }
58
+ }
59
+ tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
60
+ tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
61
+ *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
62
+ tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
63
+ }
64
+ FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
65
+ }
66
+ __syncwarp();
67
+
68
+ for (int i = threadIdx.x; i <= batch_size; i += 32) {
69
+ num_splits_ptr[i] = num_splits_shared[i];
70
+ }
71
+ }
72
+
73
+ void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
74
+ FLASH_ASSERT(params.batch_size < MaxBatchSize);
75
+ get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
76
+ CHECK_CUDA_KERNEL_LAUNCH();
77
+ }
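The scheduler above spreads the total work, counted as KV blocks plus a fixed per-tile overhead, roughly evenly over `num_sm_parts` schedule entries. A small worked example with hypothetical inputs shows how the per-part `payload` comes out:

```cpp
#include <cstdio>

// Hypothetical sizing example for get_mla_metadata_kernel's "payload":
// two sequences of KV length 1000 and 3000, block_size_n = 64,
// fixed_overhead_num_blocks = 5, num_sm_parts = 4.
int main() {
    const int block_size_n = 64, fixed_overhead = 5, num_sm_parts = 4;
    const int seqlens_k[] = {1000, 3000};

    int total_num_blocks = 0;
    for (int s : seqlens_k) {
        int num_blocks = (s + block_size_n - 1) / block_size_n;  // 16 and 47
        total_num_blocks += num_blocks + fixed_overhead;          // -> 73
    }
    int payload = (total_num_blocks + num_sm_parts - 1) / num_sm_parts + fixed_overhead;
    std::printf("payload per SM part: %d block-units\n", payload);  // 24
    return 0;
}
```

Sequences that do not fit in one part's remaining payload are split across parts, and the resulting per-batch split counts are written out as the prefix sums stored in `num_splits_ptr`.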
flash_mla/flash_mla.h ADDED
@@ -0,0 +1,63 @@
+ #pragma once
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ struct Flash_fwd_mla_params {
+     using index_t = int64_t;
+
+     int b, seqlen_q, d, d_v;
+     int h, h_h_k_ratio, ngroups;
+     bool is_causal;
+     float scale_softmax, scale_softmax_log2;
+     int *__restrict__ cu_seqlens_k;
+
+     void *__restrict__ q_ptr;
+     void *__restrict__ k_ptr;
+     void *__restrict__ v_ptr;
+     void *__restrict__ o_ptr;
+     void *__restrict__ softmax_lse_ptr;
+
+     index_t q_batch_stride;
+     index_t k_batch_stride;
+     index_t v_batch_stride;
+     index_t o_batch_stride;
+     index_t q_row_stride;
+     index_t k_row_stride;
+     index_t v_row_stride;
+     index_t o_row_stride;
+     index_t q_head_stride;
+     index_t k_head_stride;
+     index_t v_head_stride;
+     index_t o_head_stride;
+
+     int *__restrict__ block_table;
+     index_t block_table_batch_stride;
+     int page_block_size;
+
+     int *__restrict__ tile_scheduler_metadata_ptr;
+     int num_sm_parts;
+     int *__restrict__ num_splits_ptr;
+
+     void *__restrict__ softmax_lseaccum_ptr;
+     void *__restrict__ oaccum_ptr;
+ };
+
+ static constexpr int TileSchedulerMetaDataSize = 8;
+ // [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template<typename T, int Headdim>
+ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);
+
+ struct Mla_metadata_params {
+     int *__restrict__ seqlens_k_ptr;
+     int *__restrict__ tile_scheduler_metadata_ptr;
+     int *__restrict__ num_splits_ptr;
+     int batch_size;
+     int block_size_n;
+     int fixed_overhead_num_blocks;
+     int num_sm_parts;
+ };
+
+ void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);
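One detail worth calling out: `num_splits_ptr` is consumed as a prefix sum. For batch `b`, the store path in the kernel reads `split_offset = num_splits_ptr[b]`, and the combine kernel additionally takes `num_splits_ptr[b + 1] - num_splits_ptr[b]` as the number of partial results to merge; both index into `softmax_lseaccum_ptr` and `oaccum_ptr` starting at `split_offset`. For example, if the metadata kernel produced `{0, 1, 4, 5}` for three sequences, batches 0 and 2 each have a single partial result, while batch 1 was split three ways and its partials occupy accumulator slots 1 through 3.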
flash_mla/flash_mla_api.cu ADDED
@@ -0,0 +1,259 @@
1
+ #include <ATen/cuda/CUDAContext.h>
2
+ #include <c10/cuda/CUDAGuard.h>
3
+ #include <torch/all.h>
4
+
5
+
6
+ #include <cutlass/fast_math.h>
7
+
8
+ #include "flash_mla.h"
9
+ #include "static_switch.h"
10
+
11
+ #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
12
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
+
15
+
16
+ //
17
+
18
+
19
+ // #include <cmath>
20
+
21
+ // #include "cute/tensor.hpp"
22
+ #include <cute/tensor.hpp>
23
+
24
+ // __global__ void relu_kernel(float *__restrict__ out,
25
+ // float const *__restrict__ input,
26
+ // const int d) {
27
+ // const int64_t token_idx = blockIdx.x;
28
+ // for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
29
+ // auto x = input[token_idx * d + idx];
30
+ // out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
31
+ // }
32
+ // }
33
+
34
+ // void relu(torch::Tensor &out,
35
+ // torch::Tensor const &input)
36
+ // {
37
+ // TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
38
+ // input.scalar_type() == at::ScalarType::Float,
39
+ // "relu_kernel only supports float32");
40
+
41
+ // int d = input.size(-1);
42
+ // int64_t num_tokens = input.numel() / d;
43
+ // dim3 grid(num_tokens);
44
+ // dim3 block(std::min(d, 1024));
45
+ // const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
46
+ // const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
47
+ // relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
48
+ // input.data_ptr<float>(), d);
49
+ // }
50
+
51
+ std::vector<at::Tensor>
52
+ get_mla_metadata(
53
+ at::Tensor &seqlens_k,
54
+ const int64_t num_heads_per_head_k,
55
+ const int64_t num_heads_k
56
+ ) {
57
+ // This should match the logic in the MLA kernel.
58
+ static constexpr int block_size_m = 64;
59
+ static constexpr int block_size_n = 64;
60
+ static constexpr int fixed_overhead_num_blocks = 5;
61
+
62
+ CHECK_DEVICE(seqlens_k);
63
+ TORCH_CHECK(seqlens_k.is_contiguous());
64
+ TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
65
+
66
+ int batch_size = seqlens_k.size(0);
67
+ int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
68
+ auto options = seqlens_k.options();
69
+
70
+ auto dprops = at::cuda::getCurrentDeviceProperties();
71
+ int sm_count = dprops->multiProcessorCount;
72
+ int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
73
+
74
+ auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
75
+ auto num_splits = torch::empty({batch_size + 1}, options);
76
+ int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
77
+ int *num_splits_ptr = num_splits.data_ptr<int>();
78
+
79
+ at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
80
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
81
+ Mla_metadata_params params = {};
82
+ params.seqlens_k_ptr = seqlens_k_ptr;
83
+ params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
84
+ params.num_splits_ptr = num_splits_ptr;
85
+ params.batch_size = batch_size;
86
+ params.block_size_n = block_size_n;
87
+ params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
88
+ params.num_sm_parts = num_sm_parts;
89
+ get_mla_metadata_func(params, stream);
90
+
91
+ return {tile_scheduler_metadata, num_splits};
92
+ }
93
+
94
+ std::vector<at::Tensor>
95
+ mha_fwd_kvcache_mla(
96
+ at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
97
+ const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
98
+
99
+ // TODO: fix for optional
100
+ // std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
101
+
102
+ const at::Tensor &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
103
+ const int64_t head_size_v,
104
+ const at::Tensor &seqlens_k, // batch_size
105
+ const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
106
+ // TODO: should be float
107
+ const double softmax_scale,
108
+ const bool is_causal_,
109
+ const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
110
+ const at::Tensor &num_splits, // batch_size + 1
111
+
112
+ // TODO: remove this once determined why build is adding this parameter
113
+ const int64_t unknown_param
114
+ ) {
115
+ auto dprops = at::cuda::getCurrentDeviceProperties();
116
+ bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
117
+ TORCH_CHECK(is_sm90);
118
+
119
+ // TODO: fix for mutable bool
120
+ bool is_causal = is_causal_;
121
+
122
+
123
+ // TODO: fix for optional
124
+ // at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
125
+ at::Tensor vcache = vcache_;
126
+
127
+ auto q_dtype = q.dtype();
128
+ TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
129
+
130
+ CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
131
+
132
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
133
+ TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
134
+ TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
135
+
136
+ CHECK_DEVICE(block_table);
137
+ TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
138
+ TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
139
+
140
+ const auto sizes = q.sizes();
141
+ const int batch_size = sizes[0];
142
+ const int seqlen_q_ori = sizes[1];
143
+ const int num_heads_ori = sizes[2];
144
+ const int head_size = sizes[3];
145
+ TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
146
+ TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
147
+
148
+ const int max_num_blocks_per_seq = block_table.size(1);
149
+ const int num_blocks = kcache.size(0);
150
+ const int page_block_size = kcache.size(1);
151
+ const int num_heads_k = kcache.size(2);
152
+ TORCH_CHECK(batch_size > 0, "batch size must be positive");
153
+ TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
154
+
155
+ if (seqlen_q_ori == 1) { is_causal = false; }
156
+
157
+ const int ngroups = num_heads_ori / num_heads_k;
158
+ const int seqlen_q = seqlen_q_ori * ngroups;
159
+ const int num_heads = num_heads_k;
160
+ q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
161
+ .reshape({batch_size, seqlen_q, num_heads, head_size});
162
+
163
+ int head_size_k = head_size;
164
+ CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
165
+ CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
166
+
167
+ // TODO: fix for optional
168
+ // if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
169
+ CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v);
170
+
171
+ CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
172
+
173
+
174
+ TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
175
+ CHECK_DEVICE(seqlens_k);
176
+ CHECK_CONTIGUOUS(seqlens_k);
177
+ CHECK_SHAPE(seqlens_k, batch_size);
178
+
179
+ at::cuda::CUDAGuard device_guard{(char)q.get_device()};
180
+
181
+ auto opts = q.options();
182
+ at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
183
+ at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
184
+
185
+ Flash_fwd_mla_params params = {};
186
+ // Set the sizes.
187
+ params.b = batch_size;
188
+ params.seqlen_q = seqlen_q;
189
+ params.cu_seqlens_k = seqlens_k.data_ptr<int>();
190
+ params.h = num_heads;
191
+ params.h_h_k_ratio = num_heads / num_heads_k;
192
+ params.ngroups = ngroups;
193
+ params.is_causal = is_causal;
194
+ params.d = head_size;
195
+ params.d_v = head_size_v;
196
+ params.scale_softmax = softmax_scale;
197
+ params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
198
+ // Set the pointers and strides.
199
+ params.q_ptr = q.data_ptr();
200
+ params.k_ptr = kcache.data_ptr();
201
+ params.v_ptr = vcache.data_ptr();
202
+ params.o_ptr = out.data_ptr();
203
+ params.softmax_lse_ptr = softmax_lse.data_ptr();
204
+ // All stride are in elements, not bytes.
205
+ params.q_batch_stride = q.stride(0);
206
+ params.k_batch_stride = kcache.stride(0);
207
+ params.v_batch_stride = vcache.stride(0);
208
+ params.o_batch_stride = out.stride(0);
209
+ params.q_row_stride = q.stride(-3);
210
+ params.k_row_stride = kcache.stride(-3);
211
+ params.v_row_stride = vcache.stride(-3);
212
+ params.o_row_stride = out.stride(-3);
213
+ params.q_head_stride = q.stride(-2);
214
+ params.k_head_stride = kcache.stride(-2);
215
+ params.v_head_stride = vcache.stride(-2);
216
+ params.o_head_stride = out.stride(-2);
217
+
218
+ params.block_table = block_table.data_ptr<int>();
219
+ params.block_table_batch_stride = block_table.stride(0);
220
+ params.page_block_size = page_block_size;
221
+
222
+ TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
223
+ TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
224
+ CHECK_DEVICE(tile_scheduler_metadata);
225
+ CHECK_CONTIGUOUS(tile_scheduler_metadata);
226
+ params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
227
+ params.num_sm_parts = tile_scheduler_metadata.size(0);
228
+ TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
229
+ CHECK_DEVICE(num_splits);
230
+ CHECK_CONTIGUOUS(num_splits);
231
+ params.num_splits_ptr = num_splits.data_ptr<int>();
232
+
233
+ at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
234
+ at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
235
+ params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
236
+ params.oaccum_ptr = out_accum.data_ptr();
237
+
238
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
239
+ TORCH_CHECK(head_size == 576);
240
+
241
+ if (q_dtype == torch::kBFloat16) {
242
+ run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
243
+ }
244
+ #ifndef FLASH_MLA_DISABLE_FP16
245
+ else if (q_dtype == torch::kHalf) {
246
+ run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
247
+ }
248
+ #endif
249
+ else {
250
+ TORCH_CHECK(false, "Unsupported tensor dtype for query");
251
+ }
252
+
253
+ out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
254
+ .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
255
+ softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
256
+ .reshape({batch_size, num_heads_ori, seqlen_q_ori});
257
+
258
+ return {out, softmax_lse};
259
+ }
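Because MLA shares the key and value cache (`kcache` and `vcache` point at the same storage, see the `Shared_KV` assert in `run_mha_fwd_splitkv_mla`) and `num_heads_k` is typically 1, `mha_fwd_kvcache_mla` folds the `ngroups = num_heads / num_heads_k` query heads into the query sequence dimension before launching and unfolds the output afterwards. As a hypothetical example: with `seqlen_q_ori = 1`, 128 query heads, and a single KV head, the kernel sees `seqlen_q = 128` and `num_heads = 1`, so each block processes all query heads against one shared KV stream, and `params.ngroups` lets the causal mask treat those 128 rows as the same original query position.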
flash_mla/named_barrier.h ADDED
@@ -0,0 +1,15 @@
+ #pragma once
+
+ #include "cutlass/barrier.h"
+
+ namespace flash {
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+ // Enumerates the reserved named barriers to avoid potential conflicts
+
+ enum class NamedBarriers {
+     SReady = 1,
+     SoftmaxReady = 2,
+ };
+
+ } // namespace flash
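The two IDs above coordinate the two warp groups in `flash_fwd_splitkv_mla_kernel`: the softmax warp group arrives at `SReady` once the probabilities and rescale factors are staged in shared memory, the loader warp group syncs on it before running its share of the P·V GEMM, and `SoftmaxReady` hands the running row max/sum back at the end. A stripped-down sketch of that hand-off, assuming the same `cutlass::arch::NamedBarrier` interface used in the kernel (the bodies are placeholder comments, not the real pipeline):

```cuda
#include "cutlass/barrier.h"  // same include as named_barrier.h above

// Sketch only: two 128-thread warp groups in a 256-thread block hand work
// off through a named barrier, mirroring the SReady usage in the kernel.
__global__ void handoff_sketch() {
    constexpr int kNThreads = 256;
    constexpr int kSReady = 1;  // flash::NamedBarriers::SReady

    if (threadIdx.x < 128) {
        // ... compute S, run softmax, stage P and scale factors in smem ...
        cutlass::arch::NamedBarrier::arrive(kNThreads, kSReady);
    } else {
        // ... issue the next K tile loads ...
        cutlass::arch::NamedBarrier::sync(kNThreads, kSReady);
        // ... consume P from smem and accumulate the P*V GEMM ...
    }
}
```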
flash_mla/softmax.h ADDED
@@ -0,0 +1,197 @@
1
+ // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
2
+
3
+ #pragma once
4
+
5
+ #include <cmath>
6
+
7
+ #include <cute/tensor.hpp>
8
+ #include <cutlass/numeric_types.h>
9
+
10
+ #include "utils.h"
11
+
12
+ namespace flash {
13
+
14
+ using namespace cute;
15
+
16
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
17
+
18
+ template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
19
+ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
20
+ static_assert(Layout0::rank == 2, "Only support 2D Tensor");
21
+ static_assert(Layout1::rank == 1, "Only support 1D Tensor");
22
+ CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
23
+ #pragma unroll
24
+ for (int mi = 0; mi < size<0>(tensor); mi++) {
25
+ summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
26
+ #pragma unroll
27
+ for (int ni = 1; ni < size<1>(tensor); ni++) {
28
+ summary(mi) = op(summary(mi), tensor(mi, ni));
29
+ }
30
+ }
31
+ }
32
+
33
+ template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
34
+ __device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
35
+ CUTE_STATIC_ASSERT_V(size(dst) == size(src));
36
+ #pragma unroll
37
+ for (int i = 0; i < size(dst); i++){
38
+ dst(i) = Allreduce<4>::run(src(i), op);
39
+ }
40
+ }
41
+
42
+ template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
43
+ __device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
44
+ thread_reduce_<zero_init>(tensor, summary, op);
45
+ quad_allreduce_(summary, summary, op);
46
+ }
47
+
48
+ template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
49
+ __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
50
+ MaxOp<float> max_op;
51
+ reduce_<zero_init>(tensor, max, max_op);
52
+ }
53
+
54
+ template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
55
+ __device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
56
+ SumOp<float> sum_op;
57
+ thread_reduce_<zero_init>(tensor, sum, sum_op);
58
+ }
59
+
60
+ // Apply the exp to all the elements.
61
+ template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
62
+ __forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
63
+ static_assert(Layout0::rank == 2, "Only support 2D Tensor");
64
+ static_assert(Layout1::rank == 1, "Only support 1D Tensor");
65
+ CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
66
+ #pragma unroll
67
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
68
+ // If max is -inf, then all elements must have been -inf (possibly due to masking).
69
+ // We don't want (-inf - (-inf)) since that would give NaN.
70
+ // If we don't have float around M_LOG2E the multiplication is done in fp64.
71
+ const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
72
+ #pragma unroll
73
+ for (int ni = 0; ni < size<1>(tensor); ++ni) {
74
+ // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
75
+ // max * log_2(e)) This allows the compiler to use the ffma
76
+ // instruction instead of fadd and fmul separately.
77
+ // The following macro will disable the use of fma.
78
+ // See: https://github.com/pytorch/pytorch/issues/121558 for more details
79
+ // This macro is set in PyTorch and not FlashAttention
80
+ #ifdef UNFUSE_FMA
81
+ tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
82
+ #else
83
+ tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
84
+ #endif
85
+ }
86
+ }
87
+ return tensor;
88
+ }
89
+
90
+ // Apply the exp to all the elements.
91
+ template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
92
+ __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
93
+ static_assert(Layout0::rank == 2, "Only support 2D Tensor");
94
+ static_assert(Layout1::rank == 1, "Only support 1D Tensor");
95
+ CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
96
+ #pragma unroll
97
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
98
+ MaxOp<float> max_op;
99
+ max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
100
+ #pragma unroll
101
+ for (int ni = 1; ni < size<1>(tensor); ni++) {
102
+ max(mi) = max_op(max(mi), tensor(mi, ni));
103
+ }
104
+ max(mi) = Allreduce<4>::run(max(mi), max_op);
105
+ // If max is -inf, then all elements must have been -inf (possibly due to masking).
106
+ // We don't want (-inf - (-inf)) since that would give NaN.
107
+ const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
108
+ sum(mi) = 0;
109
+ #pragma unroll
110
+ for (int ni = 0; ni < size<1>(tensor); ++ni) {
111
+ // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
112
+ // max * log_2(e)) This allows the compiler to use the ffma
113
+ // instruction instead of fadd and fmul separately.
114
+ tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
115
+ sum(mi) += tensor(mi, ni);
116
+ }
117
+ SumOp<float> sum_op;
118
+ sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
119
+ }
120
+ }
121
+
122
+ template<typename Tensor0, typename Tensor1>
123
+ __forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
124
+ // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
125
+ Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
126
+ #pragma unroll
127
+ for (int mi = 0; mi < size(scale_o); ++mi) {
128
+ #pragma unroll
129
+ for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); }
130
+ }
131
+ }
132
+
133
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
134
+
135
+ template <int kNRows>
136
+ struct Softmax {
137
+
138
+ using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
139
+ TensorT row_max, row_sum;
140
+
141
+ __forceinline__ __device__ Softmax() {};
142
+
143
+ template<bool Is_first, bool Check_inf=false, typename Tensor0>
144
+ __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
145
+ // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
146
+ Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
147
+ static_assert(decltype(size<0>(scores))::value == kNRows);
148
+ TensorT scale_o;
149
+ clear(scale_o);
150
+ if (Is_first) {
151
+ flash::template reduce_max</*zero_init=*/true>(scores, row_max);
152
+ flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
153
+ flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
154
+ } else {
155
+ Tensor scores_max_prev = make_fragment_like(row_max);
156
+ cute::copy(row_max, scores_max_prev);
157
+ flash::template reduce_max</*zero_init=*/false>(scores, row_max);
158
+ // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
159
+ #pragma unroll
160
+ for (int mi = 0; mi < size(row_max); ++mi) {
161
+ float scores_max_cur = !Check_inf
162
+ ? row_max(mi)
163
+ : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
164
+ float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
165
+ scale_o(mi) = scores_scale;
166
+ row_sum(mi) *= scores_scale;
167
+ }
168
+ flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
169
+ // We don't do the reduce across threads here since we don't need to use the row_sum.
170
+ // We do that reduce at the end when we need to normalize the softmax.
171
+ flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
172
+ }
173
+ return scale_o;
174
+ };
175
+
176
+ template<bool Is_dropout=false, bool Split=false, typename Tensor0>
177
+ __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
178
+ SumOp<float> sum_op;
179
+ quad_allreduce_(row_sum, row_sum, sum_op);
180
+ TensorT lse = make_fragment_like(row_sum);
181
+ // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
182
+ Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
183
+ static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
184
+ #pragma unroll
185
+ for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
186
+ float sum = row_sum(mi);
187
+ float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
188
+ lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
189
+ float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
190
+ #pragma unroll
191
+ for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
192
+ }
193
+ return lse;
194
+ };
195
+ };
196
+
197
+ } // namespace flash
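softmax.h implements the streaming (online) softmax: each new block of scores updates a running row max, rescales the running row sum and the output accumulator by `exp2((max_prev - max_cur) * scale_log2)`, and `normalize_softmax_lse` finally divides by the sum and emits `lse = max * scale + log(sum)`. A scalar Python sketch of that recurrence, assuming plain floats in place of the per-thread CUTE fragments (illustrative only, not the kernel's data path):

```python
import math

def streaming_softmax(score_blocks, softmax_scale):
    """Scalar stand-in for Softmax::softmax / normalize_softmax_lse (illustration only)."""
    scale_log2 = softmax_scale * math.log2(math.e)   # softmax_scale_log2 in the kernel
    row_max, row_sum, acc = -math.inf, 0.0, []
    for block in score_blocks:
        new_max = max(row_max, max(block))
        # Rescale everything accumulated so far to the new running max (rescale_o, row_sum *= scale).
        scale_o = 0.0 if row_max == -math.inf else 2.0 ** ((row_max - new_max) * scale_log2)
        row_sum *= scale_o
        acc = [a * scale_o for a in acc]
        # exp2 with the log2(e)-folded scale, as scale_apply_exp2 does.
        probs = [2.0 ** (x * scale_log2 - new_max * scale_log2) for x in block]
        row_sum += sum(probs)
        acc.extend(probs)            # stand-in for accumulating P @ V
        row_max = new_max
    lse = row_max * softmax_scale + math.log(row_sum)
    return [a / row_sum for a in acc], lse

probs, lse = streaming_softmax([[1.0, 2.0], [3.0, 0.5]], softmax_scale=0.5)
print(round(sum(probs), 6), round(lse, 4))  # probabilities sum to 1.0; lse is the scaled log-sum-exp
```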
flash_mla/static_switch.h ADDED
@@ -0,0 +1,65 @@
+ #pragma once
+
+ #define CHECK_CUDA(call) \
+ do { \
+ cudaError_t status_ = call; \
+ if (status_ != cudaSuccess) { \
+ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
+ exit(1); \
+ } \
+ } while(0)
+
+ #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
+
+
+ #define FLASH_ASSERT(cond) \
+ do { \
+ if (not (cond)) { \
+ fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
+ exit(1); \
+ } \
+ } while(0)
+
+
+ #define FLASH_DEVICE_ASSERT(cond) \
+ do { \
+ if (not (cond)) { \
+ printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
+ asm("trap;"); \
+ } \
+ } while(0)
+
+
+ #define BOOL_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ if (COND) { \
+ constexpr static bool CONST_NAME = true; \
+ return __VA_ARGS__(); \
+ } else { \
+ constexpr static bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ } \
+ }()
+
+
+ #define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
+ [&] { \
+ if (NUM_SPLITS <= 32) { \
+ constexpr static int NAME = 32; \
+ return __VA_ARGS__(); \
+ } else if (NUM_SPLITS <= 64) { \
+ constexpr static int NAME = 64; \
+ return __VA_ARGS__(); \
+ } else if (NUM_SPLITS <= 96) { \
+ constexpr static int NAME = 96; \
+ return __VA_ARGS__(); \
+ } else if (NUM_SPLITS <= 128) { \
+ constexpr static int NAME = 128; \
+ return __VA_ARGS__(); \
+ } else if (NUM_SPLITS <= 160) { \
+ constexpr static int NAME = 160; \
+ return __VA_ARGS__(); \
+ } else { \
+ FLASH_ASSERT(false); \
+ } \
+ }()
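`MLA_NUM_SPLITS_SWITCH` maps a runtime split count onto the nearest compiled kernel variant (32, 64, 96, 128 or 160) and asserts if the count exceeds the largest bucket. A rough Python mirror of that dispatch logic, purely for illustration; the real selection happens at compile time via the lambda-wrapped `constexpr` above:

```python
def pick_num_splits_bucket(num_splits: int) -> int:
    """Bucket a runtime split count up to the nearest instantiated kernel size."""
    for bucket in (32, 64, 96, 128, 160):
        if num_splits <= bucket:
            return bucket
    raise AssertionError("num_splits exceeds the largest compiled bucket (160)")

assert pick_num_splits_bucket(1) == 32
assert pick_num_splits_bucket(97) == 128
```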
flash_mla/utils.h ADDED
@@ -0,0 +1,238 @@
+ // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h
+
+ #pragma once
+
+ #include <assert.h>
+ #include <stdint.h>
+ #include <stdlib.h>
+
+ #include <cuda_bf16.h>
+
+ #include <cute/tensor.hpp>
+
+ #include <cutlass/array.h>
+ #include <cutlass/cutlass.h>
+ #include <cutlass/numeric_conversion.h>
+ #include <cutlass/numeric_types.h>
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ namespace flash {
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template<typename T>
+ struct MaxOp {
+ __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
+ };
+
+ template <>
+ struct MaxOp<float> {
+ // This is slightly faster
+ __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template<typename T>
+ struct SumOp {
+ __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template<int THREADS>
+ struct Allreduce {
+ static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+ template<typename T, typename Operator>
+ static __device__ __forceinline__ T run(T x, Operator &op) {
+ constexpr int OFFSET = THREADS / 2;
+ x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+ return Allreduce<OFFSET>::run(x, op);
+ }
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template<>
+ struct Allreduce<2> {
+ template<typename T, typename Operator>
+ static __device__ __forceinline__ T run(T x, Operator &op) {
+ x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+ return x;
+ }
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
+ __forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
+ constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
+ // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
+ if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
+ warpgroup_fence_operand(tCrC);
+ if constexpr (arrive) {
+ warpgroup_arrive();
+ }
+ if constexpr (zero_init) {
+ tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+ // Unroll the K mode manually to set scale D to 1
+ CUTLASS_PRAGMA_UNROLL
+ for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+ cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+ }
+ } else {
+ // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
+ // Unroll the K mode manually to set scale D to 1
+ CUTLASS_PRAGMA_UNROLL
+ for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+ cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+ }
+ }
+ if constexpr (commit) {
+ warpgroup_commit_batch();
+ }
+ if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
+ warpgroup_fence_operand(tCrC);
+ if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ // For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
+ template<bool Transposed=false, typename Layout0>
+ __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
+ if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
+ static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
+ static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
+ static_assert(decltype(rank(acc_layout))::value == 3);
+ auto l = acc_layout;
+ if constexpr (!Transposed) {
+ return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
+ } else {
+ return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
+ }
+
+ } else { // SM80
+ static_assert(decltype(size<0>(acc_layout))::value == 4);
+ static_assert(decltype(rank(acc_layout))::value == 3);
+ auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
+ if constexpr (!Transposed) {
+ return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+ } else {
+ return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
+ }
+ }
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+ // if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
+ // For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
+ // For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
+ template<typename MMA_Traits, typename Layout0>
+ __forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
+ using X = Underscore;
+ if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
+ static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
+ static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
+ static_assert(decltype(rank(acc_layout))::value == 3);
+ static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
+ if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
+ auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16))
+ return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
+ } else {
+ static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
+ static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
+ static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
+ auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{}); // (((2, 2), N / 32))
+ // This combines the first two modes (<0, 0> and <0, 1>) into one mode.
+ // Will require register shuffling later to be correct.
+ return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
+ get<1>(acc_layout),
+ coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
+ // This combination is right but doesn't work with register shuffling.
+ // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
+ // get<1>(acc_layout),
+ // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
+ }
+ } else { // SM80
+ static_assert(decltype(size<0>(acc_layout))::value == 4);
+ static_assert(decltype(rank(acc_layout))::value == 3);
+ constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
+ static_assert(mma_shape_K == 8 || mma_shape_K == 16);
+ if constexpr (mma_shape_K == 8) {
+ return acc_layout;
+ } else {
+ auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
+ return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+ }
+ }
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template <typename To_type, typename Engine, typename Layout>
+ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
+ using From_type = typename Engine::value_type;
+ constexpr int numel = decltype(size(tensor))::value;
+ cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
+ // HACK: this requires tensor to be "contiguous"
+ auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
+ return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ // Blocks until all but N previous cp.async.commit_group operations have committed.
+ // This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
+ // (which is equivalent to commit_group then wait_group 0).
+ // Instead we just call cp.async.wait_group 0, which is slightly faster.
+ // https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
+ template <int N>
+ CUTE_HOST_DEVICE
+ void cp_async_wait() {
+ #if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
+ asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
+ #endif
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
+ typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
+ Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+ Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
+ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
+ // There's no case where !Clear_OOB_K && Clear_OOB_MN
+ static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+ #pragma unroll
+ for (int m = 0; m < size<1>(S); ++m) {
+ if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+ #pragma unroll
+ for (int k = 0; k < size<2>(S); ++k) {
+ if (Is_even_K || predicate_K(k)) {
+ cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+ } else if (Clear_OOB_K) {
+ cute::clear(D(_, m, k));
+ }
+ }
+ } else if (Clear_OOB_MN) {
+ cute::clear(D(_, m, _));
+ }
+ }
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ } // namespace flash
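`Allreduce<4>` in utils.h performs a butterfly reduction: each step combines a lane with the lane whose index differs by one bit (`__shfl_xor_sync`), halving the stride until every lane in the group of four holds the full result. A toy Python simulation of one such group, where a list stands in for the warp lanes (an assumption made purely for illustration):

```python
def allreduce(lanes, op, threads=4):
    """Butterfly allreduce over `threads` simulated lanes (mirrors Allreduce<4>::run)."""
    assert len(lanes) == threads
    offset = threads // 2
    while offset >= 1:
        # Each lane combines with the lane whose index differs in exactly one bit.
        lanes = [op(lanes[i], lanes[i ^ offset]) for i in range(threads)]
        offset //= 2
    return lanes

print(allreduce([3.0, 1.0, 4.0, 1.5], max))         # every lane ends with 4.0
print(allreduce([1, 2, 3, 4], lambda a, b: a + b))  # every lane ends with 10
```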
tests/__init__.py ADDED
File without changes
tests/test_flash_mla.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import random
+ import torch.nn.functional as F
+
+ import flash_mla
+
+ # TODO: revise to use the same test as the original code
+
+
+ def test_flash_mla():
+     # b = 128
+     # s_q = 4096
+     # mean_sk = 8192
+     # h_q = 16
+     # h_kv = 1
+     # d = 576
+     # dv = 512
+
+     b = 16
+     s_q = 16
+     mean_sk = 16
+     h_q = 16
+     h_kv = 1
+     d = 576
+     dv = 512
+
+
+     causal = True
+     varlen = False
+
+     print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
+
+     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
+     if varlen:
+         for i in range(b):
+             cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
+     total_seqlens = cache_seqlens.sum().item()
+     mean_seqlens = cache_seqlens.float().mean().int().item()
+     max_seqlen = cache_seqlens.max().item()
+     # TODO: avoid triton from original code
+     # max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
+     print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
+     max_seqlen_pad = (max_seqlen + 255) & ~255  # round up to multiple of 256
+     q = torch.randn(b, s_q, h_q, d)
+     block_size = 64
+     block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(
+         b, max_seqlen_pad // block_size
+     )
+     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
+     print(blocked_k.shape)
+     for i in range(b):
+         blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = float(
+             "nan"
+         )
+     blocked_v = blocked_k[..., :dv]
+     print(blocked_k.shape, blocked_v.shape)
+
+     cache_seqlens = cache_seqlens.to("cuda")
+
+     tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
+         seqlens_k=cache_seqlens,
+         #
+         s_q=s_q * h_q // h_kv,
+         h_kv=h_kv,
+     )
+     print(tile_scheduler_metadata, num_splits)
+
+     # TODO: update to expect the correct output
+     assert False
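The test above exercises the paged KV cache layout: `block_table[i, j]` names the physical block that holds tokens `j*block_size … (j+1)*block_size - 1` of sequence `i`, and `blocked_k` stores all blocks back to back. A hedged sketch of that addressing with tiny made-up sizes; the gather below is only illustrative, the kernel performs the equivalent lookup internally:

```python
import torch

b, block_size, h_kv, d = 2, 4, 1, 8          # tiny made-up sizes
blocks_per_seq = 3
block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32).view(b, blocks_per_seq)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)

def lookup_k(batch_idx: int, token_idx: int) -> torch.Tensor:
    """Return the (h_kv, d) key row for one token via the block table."""
    block = block_table[batch_idx, token_idx // block_size]
    return blocked_k[block, token_idx % block_size]

# With this identity block table the gather matches the contiguous view the test
# uses when it poisons the padded region with NaNs.
flat = blocked_k.view(b, blocks_per_seq * block_size, h_kv, d)
assert torch.equal(lookup_k(1, 5), flat[1, 5])
```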
torch-ext/flash_mla/__init__.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+
+ from ._ops import ops
+
+
+ def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int) -> list[torch.Tensor]:
+     return ops.get_mla_metadata(seqlens_k, s_q, h_kv)
+
+
+ def mha_fwd_kvcache_mla(
+     q: torch.Tensor,
+     kcache: torch.Tensor,
+     vcache_: torch.Tensor,
+     head_size_v: int,
+     seqlens_k: torch.Tensor,
+     block_table: torch.Tensor,
+     softmax_scale: float,
+     is_causal_: bool,
+     tile_scheduler_metadata: torch.Tensor,
+     num_splits: torch.Tensor,
+ ) -> list[torch.Tensor]:
+     # TODO: remove when resolved
+     unknown_param = 0
+     return ops.mha_fwd_kvcache_mla(
+         q,
+         kcache,
+         vcache_,
+         head_size_v,
+         seqlens_k,
+         block_table,
+         softmax_scale,
+         is_causal_,
+         tile_scheduler_metadata,
+         num_splits,
+         unknown_param,
+     )
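Putting the two wrappers together, a hedged end-to-end usage sketch. The shapes follow the test file; the bfloat16/CUDA choices and the `softmax_scale` value are assumptions (the kernel in this repo only builds for SM90 and the API checks `head_size == 576`), and the expected output shapes come from the reshape at the end of `mha_fwd_kvcache_mla` above:

```python
import torch
import flash_mla

b, s_q, s_k, h_q, h_kv, d, dv, block_size = 2, 1, 256, 16, 1, 576, 512, 64

q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
blocked_k = torch.randn(b * s_k // block_size, block_size, h_kv, d,
                        dtype=torch.bfloat16, device="cuda")
blocked_v = blocked_k[..., :dv]                     # value cache is a slice of the key cache
block_table = torch.arange(b * s_k // block_size, dtype=torch.int32,
                           device="cuda").view(b, s_k // block_size)
cache_seqlens = torch.full((b,), s_k, dtype=torch.int32, device="cuda")

tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
    cache_seqlens, s_q * h_q // h_kv, h_kv
)
out, softmax_lse = flash_mla.mha_fwd_kvcache_mla(
    q, blocked_k, blocked_v, dv, cache_seqlens, block_table,
    softmax_scale=d ** -0.5,          # assumed scale; use whatever your model expects
    is_causal_=True,
    tile_scheduler_metadata=tile_scheduler_metadata,
    num_splits=num_splits,
)
print(out.shape, softmax_lse.shape)   # expected (b, s_q, h_q, dv) and (b, h_q, s_q)
```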
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,15 @@
+ #include <torch/library.h>
+
+ #include "registration.h"
+ #include "torch_binding.h"
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+ ops.def("get_mla_metadata(Tensor! seqlens_k, int num_heads_per_head_k, int num_heads_k) -> Tensor[]");
+ ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
+
+ // TODO: remove last unknown_param when resolved
+ ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor! vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits, int unknown_param) -> Tensor[]");
+ ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,36 @@
+ #pragma once
+
+ #include <torch/torch.h>
+
+ std::vector<torch::Tensor>
+ get_mla_metadata(
+ torch::Tensor &seqlens_k,
+ const int64_t num_heads_per_head_k,
+ const int64_t num_heads_k
+ );
+
+ std::vector<torch::Tensor>
+ mha_fwd_kvcache_mla(
+ torch::Tensor &q,
+ const torch::Tensor &kcache,
+
+ // TODO: fix for optional
+ // std::optional<torch::Tensor> &vcache_,
+
+ const torch::Tensor &vcache_,
+ const int64_t head_size_v,
+ const torch::Tensor &seqlens_k,
+ const torch::Tensor &block_table,
+
+ // TODO: should be float
+ const double softmax_scale,
+
+ // TODO: fix for mutable bool
+ const bool is_causal_,
+
+ const torch::Tensor &tile_scheduler_metadata,
+ const torch::Tensor &num_splits,
+
+ // TODO: remove when resolved
+ const int64_t unknown_param = 0
+ );