Kernels
danieldk HF Staff commited on
Commit
17604a0
·
verified ·
1 Parent(s): efc5faa

Build uploaded using `kernels` (batch 8/10).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. build/torch29-cxx11-cu128-x86_64-linux/metadata.json +8 -2
  3. build/torch29-cxx11-cu129-x86_64-linux/__init__.py +684 -0
  4. build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so +3 -0
  5. build/torch29-cxx11-cu129-x86_64-linux/_ops.py +9 -0
  6. build/torch29-cxx11-cu129-x86_64-linux/deep_gemm/__init__.py +26 -0
  7. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/cute_tie.cuh +48 -0
  8. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/epilogue_utils.cuh +27 -0
  9. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/reduction.cuh +44 -0
  10. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/scheduler.cuh +288 -0
  11. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/sm100_utils.cuh +266 -0
  12. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/sm90_utils.cuh +332 -0
  13. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/tma_utils.cuh +116 -0
  14. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/types.hpp +41 -0
  15. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/utils.cuh +183 -0
  16. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh +482 -0
  17. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +265 -0
  18. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +563 -0
  19. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +404 -0
  20. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +398 -0
  21. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh +345 -0
  22. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh +381 -0
  23. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh +174 -0
  24. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +349 -0
  25. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +440 -0
  26. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh +329 -0
  27. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +413 -0
  28. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh +287 -0
  29. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh +67 -0
  30. build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh +176 -0
  31. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h +121 -0
  32. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h +59 -0
  33. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h +383 -0
  34. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h +719 -0
  35. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h +763 -0
  36. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h +450 -0
  37. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h +749 -0
  38. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h +798 -0
  39. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h +352 -0
  40. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h +300 -0
  41. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h +811 -0
  42. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h +157 -0
  43. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h +521 -0
  44. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h +94 -0
  45. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h +749 -0
  46. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h +740 -0
  47. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h +817 -0
  48. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h +804 -0
  49. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h +503 -0
  50. build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h +384 -0
.gitattributes CHANGED
@@ -9,3 +9,4 @@ build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=l
9
  build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
10
  build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
11
  build/torch29-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
9
  build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
10
  build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
11
  build/torch29-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
12
+ build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
build/torch29-cxx11-cu128-x86_64-linux/metadata.json CHANGED
@@ -1,5 +1,11 @@
1
  {
2
  "version": 1,
3
  "license": "MIT",
4
- "python-depends": []
5
- }
 
 
 
 
 
 
 
1
  {
2
  "version": 1,
3
  "license": "MIT",
4
+ "python-depends": [],
5
+ "backend": {
6
+ "type": "cuda",
7
+ "archs": [
8
+ "9.0a"
9
+ ]
10
+ }
11
+ }
build/torch29-cxx11-cu129-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import torch
4
+
5
+ # Import the compiled extension
6
+ from ._ops import ops
7
+ from . import utils
8
+
9
+ __version__ = "2.3.0"
10
+
11
+
12
+ # Runtime
13
+
14
+
15
+ def set_num_sms(num_sms: int):
16
+ ops.set_num_sms(num_sms)
17
+
18
+
19
+ def get_num_sms() -> int:
20
+ return ops.get_num_sms()
21
+
22
+
23
+ def set_tc_util(tc_util: int):
24
+ ops.set_tc_util(tc_util)
25
+
26
+
27
+ def get_tc_util() -> int:
28
+ return ops.get_tc_util()
29
+
30
+
31
+ def get_mk_alignment_for_contiguous_layout() -> int:
32
+ return ops.get_mk_alignment_for_contiguous_layout()
33
+
34
+
35
+ # Layout utilities
36
+
37
+
38
+ def get_tma_aligned_size(mn: int, element_size: int) -> int:
39
+ return ops.get_tma_aligned_size(mn, element_size).item()
40
+
41
+
42
+ def get_mn_major_tma_aligned_tensor(sf):
43
+ return ops.get_mn_major_tma_aligned_tensor(sf)
44
+
45
+
46
+ def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
47
+ return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
48
+
49
+
50
+ def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
51
+ ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu")
52
+ return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
53
+ sf, ks_tensor, ks_int
54
+ )
55
+
56
+
57
+ def transform_sf_into_required_layout(
58
+ sf,
59
+ mn,
60
+ k,
61
+ recipe=None,
62
+ recipe_ab=None,
63
+ num_groups=None,
64
+ is_sfa=False,
65
+ disable_ue8m0_cast=False,
66
+ ):
67
+ has_recipe = recipe is not None
68
+ r0, r1, r2 = recipe if has_recipe else (0, 0, 0)
69
+ has_recipe_ab = recipe_ab is not None
70
+ rab0, rab1 = recipe_ab if has_recipe_ab else (0, 0)
71
+ has_ng = num_groups is not None
72
+ ng = num_groups if has_ng else 0
73
+ return ops.transform_sf_into_required_layout(
74
+ sf,
75
+ mn,
76
+ k,
77
+ r0,
78
+ r1,
79
+ r2,
80
+ has_recipe,
81
+ rab0,
82
+ rab1,
83
+ has_recipe_ab,
84
+ ng,
85
+ has_ng,
86
+ is_sfa,
87
+ disable_ue8m0_cast,
88
+ )
89
+
90
+
91
+ # Aliases for contiguous layout alignment
92
+ get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
93
+ get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
94
+
95
+
96
+ # Helper to flatten recipe args
97
+
98
+
99
+ def _flatten_recipe(recipe, recipe_a=None, recipe_b=None):
100
+ has_recipe = recipe is not None
101
+ r0, r1, r2 = recipe if has_recipe else (0, 0, 0)
102
+ has_ra = recipe_a is not None
103
+ ra0, ra1 = recipe_a if has_ra else (0, 0)
104
+ has_rb = recipe_b is not None
105
+ rb0, rb1 = recipe_b if has_rb else (0, 0)
106
+ return r0, r1, r2, has_recipe, ra0, ra1, has_ra, rb0, rb1, has_rb
107
+
108
+
109
+ # FP8/FP4 GEMM ops
110
+
111
+
112
+ def fp8_fp4_gemm_nt(
113
+ a,
114
+ b,
115
+ d,
116
+ c=None,
117
+ recipe=None,
118
+ recipe_a=None,
119
+ recipe_b=None,
120
+ compiled_dims="nk",
121
+ disable_ue8m0_cast=False,
122
+ ):
123
+ a_data, a_sf = a
124
+ b_data, b_sf = b
125
+ r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
126
+ recipe, recipe_a, recipe_b
127
+ )
128
+ ops.fp8_fp4_gemm_nt(
129
+ a_data,
130
+ a_sf,
131
+ b_data,
132
+ b_sf,
133
+ d,
134
+ c,
135
+ r0,
136
+ r1,
137
+ r2,
138
+ hr,
139
+ ra0,
140
+ ra1,
141
+ hra,
142
+ rb0,
143
+ rb1,
144
+ hrb,
145
+ compiled_dims,
146
+ disable_ue8m0_cast,
147
+ )
148
+
149
+
150
+ def fp8_fp4_gemm_nn(
151
+ a,
152
+ b,
153
+ d,
154
+ c=None,
155
+ recipe=None,
156
+ recipe_a=None,
157
+ recipe_b=None,
158
+ compiled_dims="nk",
159
+ disable_ue8m0_cast=False,
160
+ ):
161
+ a_data, a_sf = a
162
+ b_data, b_sf = b
163
+ r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
164
+ recipe, recipe_a, recipe_b
165
+ )
166
+ ops.fp8_fp4_gemm_nn(
167
+ a_data,
168
+ a_sf,
169
+ b_data,
170
+ b_sf,
171
+ d,
172
+ c,
173
+ r0,
174
+ r1,
175
+ r2,
176
+ hr,
177
+ ra0,
178
+ ra1,
179
+ hra,
180
+ rb0,
181
+ rb1,
182
+ hrb,
183
+ compiled_dims,
184
+ disable_ue8m0_cast,
185
+ )
186
+
187
+
188
+ def fp8_fp4_gemm_tn(
189
+ a,
190
+ b,
191
+ d,
192
+ c=None,
193
+ recipe=None,
194
+ recipe_a=None,
195
+ recipe_b=None,
196
+ compiled_dims="mn",
197
+ disable_ue8m0_cast=False,
198
+ ):
199
+ a_data, a_sf = a
200
+ b_data, b_sf = b
201
+ r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
202
+ recipe, recipe_a, recipe_b
203
+ )
204
+ ops.fp8_fp4_gemm_tn(
205
+ a_data,
206
+ a_sf,
207
+ b_data,
208
+ b_sf,
209
+ d,
210
+ c,
211
+ r0,
212
+ r1,
213
+ r2,
214
+ hr,
215
+ ra0,
216
+ ra1,
217
+ hra,
218
+ rb0,
219
+ rb1,
220
+ hrb,
221
+ compiled_dims,
222
+ disable_ue8m0_cast,
223
+ )
224
+
225
+
226
+ def fp8_fp4_gemm_tt(
227
+ a,
228
+ b,
229
+ d,
230
+ c=None,
231
+ recipe=None,
232
+ recipe_a=None,
233
+ recipe_b=None,
234
+ compiled_dims="mn",
235
+ disable_ue8m0_cast=False,
236
+ ):
237
+ a_data, a_sf = a
238
+ b_data, b_sf = b
239
+ r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
240
+ recipe, recipe_a, recipe_b
241
+ )
242
+ ops.fp8_fp4_gemm_tt(
243
+ a_data,
244
+ a_sf,
245
+ b_data,
246
+ b_sf,
247
+ d,
248
+ c,
249
+ r0,
250
+ r1,
251
+ r2,
252
+ hr,
253
+ ra0,
254
+ ra1,
255
+ hra,
256
+ rb0,
257
+ rb1,
258
+ hrb,
259
+ compiled_dims,
260
+ disable_ue8m0_cast,
261
+ )
262
+
263
+
264
+ # FP8 aliases (same as FP8/FP4)
265
+ fp8_gemm_nt = fp8_fp4_gemm_nt
266
+ fp8_gemm_nn = fp8_fp4_gemm_nn
267
+ fp8_gemm_tn = fp8_fp4_gemm_tn
268
+ fp8_gemm_tt = fp8_fp4_gemm_tt
269
+
270
+
271
+ # M-grouped FP8/FP4 GEMM ops
272
+
273
+
274
+ def m_grouped_fp8_fp4_gemm_nt_contiguous(
275
+ a,
276
+ b,
277
+ d,
278
+ grouped_layout,
279
+ recipe=None,
280
+ recipe_a=None,
281
+ recipe_b=None,
282
+ compiled_dims="nk",
283
+ disable_ue8m0_cast=False,
284
+ use_psum_layout=False,
285
+ expected_m_for_psum_layout=None,
286
+ ):
287
+ a_data, a_sf = a
288
+ b_data, b_sf = b
289
+ r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
290
+ recipe, recipe_a, recipe_b
291
+ )
292
+ has_em = expected_m_for_psum_layout is not None
293
+ em = expected_m_for_psum_layout if has_em else 0
294
+ ops.m_grouped_fp8_fp4_gemm_nt_contiguous(
295
+ a_data,
296
+ a_sf,
297
+ b_data,
298
+ b_sf,
299
+ d,
300
+ grouped_layout,
301
+ r0,
302
+ r1,
303
+ r2,
304
+ hr,
305
+ ra0,
306
+ ra1,
307
+ hra,
308
+ rb0,
309
+ rb1,
310
+ hrb,
311
+ compiled_dims,
312
+ disable_ue8m0_cast,
313
+ use_psum_layout,
314
+ em,
315
+ has_em,
316
+ )
317
+
318
+
319
+ def m_grouped_fp8_fp4_gemm_nn_contiguous(
320
+ a,
321
+ b,
322
+ d,
323
+ grouped_layout,
324
+ recipe=None,
325
+ recipe_a=None,
326
+ recipe_b=None,
327
+ compiled_dims="nk",
328
+ disable_ue8m0_cast=False,
329
+ use_psum_layout=False,
330
+ ):
331
+ a_data, a_sf = a
332
+ b_data, b_sf = b
333
+ r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
334
+ recipe, recipe_a, recipe_b
335
+ )
336
+ ops.m_grouped_fp8_fp4_gemm_nn_contiguous(
337
+ a_data,
338
+ a_sf,
339
+ b_data,
340
+ b_sf,
341
+ d,
342
+ grouped_layout,
343
+ r0,
344
+ r1,
345
+ r2,
346
+ hr,
347
+ ra0,
348
+ ra1,
349
+ hra,
350
+ rb0,
351
+ rb1,
352
+ hrb,
353
+ compiled_dims,
354
+ disable_ue8m0_cast,
355
+ use_psum_layout,
356
+ )
357
+
358
+
359
+ def m_grouped_fp8_fp4_gemm_nt_masked(
360
+ a,
361
+ b,
362
+ d,
363
+ masked_m,
364
+ expected_m,
365
+ recipe=None,
366
+ recipe_a=None,
367
+ recipe_b=None,
368
+ compiled_dims="nk",
369
+ disable_ue8m0_cast=False,
370
+ ):
371
+ a_data, a_sf = a
372
+ b_data, b_sf = b
373
+ r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe(
374
+ recipe, recipe_a, recipe_b
375
+ )
376
+ ops.m_grouped_fp8_fp4_gemm_nt_masked(
377
+ a_data,
378
+ a_sf,
379
+ b_data,
380
+ b_sf,
381
+ d,
382
+ masked_m,
383
+ expected_m,
384
+ r0,
385
+ r1,
386
+ r2,
387
+ hr,
388
+ ra0,
389
+ ra1,
390
+ hra,
391
+ rb0,
392
+ rb1,
393
+ hrb,
394
+ compiled_dims,
395
+ disable_ue8m0_cast,
396
+ )
397
+
398
+
399
+ # M-grouped FP8 aliases
400
+ m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous
401
+ m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous
402
+ m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
403
+
404
+ # Legacy aliases
405
+ fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked
406
+
407
+
408
+ # K-grouped FP8 GEMM ops
409
+
410
+
411
+ def k_grouped_fp8_gemm_tn_contiguous(
412
+ a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn"
413
+ ):
414
+ a_data, a_sf = a
415
+ b_data, b_sf = b
416
+ r0, r1, r2 = recipe
417
+ ops.k_grouped_fp8_gemm_tn_contiguous(
418
+ a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims
419
+ )
420
+
421
+
422
+ def k_grouped_fp8_gemm_nt_contiguous(
423
+ a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn"
424
+ ):
425
+ a_data, a_sf = a
426
+ b_data, b_sf = b
427
+ r0, r1, r2 = recipe
428
+ ops.k_grouped_fp8_gemm_nt_contiguous(
429
+ a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims
430
+ )
431
+
432
+
433
+ # BF16 GEMM ops
434
+
435
+
436
+ def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"):
437
+ ops.bf16_gemm_nt(a, b, d, c, compiled_dims)
438
+
439
+
440
+ def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"):
441
+ ops.bf16_gemm_nn(a, b, d, c, compiled_dims)
442
+
443
+
444
+ def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"):
445
+ ops.bf16_gemm_tn(a, b, d, c, compiled_dims)
446
+
447
+
448
+ def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"):
449
+ ops.bf16_gemm_tt(a, b, d, c, compiled_dims)
450
+
451
+
452
+ # M-grouped BF16 GEMM ops
453
+
454
+
455
+ def m_grouped_bf16_gemm_nt_contiguous(
456
+ a,
457
+ b,
458
+ d,
459
+ grouped_layout,
460
+ compiled_dims="nk",
461
+ use_psum_layout=False,
462
+ expected_m_for_psum_layout=None,
463
+ ):
464
+ has_em = expected_m_for_psum_layout is not None
465
+ em = expected_m_for_psum_layout if has_em else 0
466
+ ops.m_grouped_bf16_gemm_nt_contiguous(
467
+ a, b, d, grouped_layout, compiled_dims, use_psum_layout, em, has_em
468
+ )
469
+
470
+
471
+ def m_grouped_bf16_gemm_nn_contiguous(
472
+ a, b, d, grouped_layout, compiled_dims="nk", use_psum_layout=False
473
+ ):
474
+ ops.m_grouped_bf16_gemm_nn_contiguous(
475
+ a, b, d, grouped_layout, compiled_dims, use_psum_layout
476
+ )
477
+
478
+
479
+ def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims="nk"):
480
+ ops.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims)
481
+
482
+
483
+ # Legacy alias
484
+ bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
485
+
486
+
487
+ # K-grouped BF16 GEMM ops
488
+
489
+
490
+ def k_grouped_bf16_gemm_tn_contiguous(
491
+ a, b, d, ks, ks_tensor, c=None, compiled_dims="mn"
492
+ ):
493
+ ops.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks_tensor, c, compiled_dims)
494
+
495
+
496
+ # cuBLASLt GEMM ops
497
+
498
+
499
+ def cublaslt_gemm_nt(a, b, d, c=None):
500
+ ops.cublaslt_gemm_nt(a, b, d, c)
501
+
502
+
503
+ def cublaslt_gemm_nn(a, b, d, c=None):
504
+ ops.cublaslt_gemm_nn(a, b, d, c)
505
+
506
+
507
+ def cublaslt_gemm_tn(a, b, d, c=None):
508
+ ops.cublaslt_gemm_tn(a, b, d, c)
509
+
510
+
511
+ def cublaslt_gemm_tt(a, b, d, c=None):
512
+ ops.cublaslt_gemm_tt(a, b, d, c)
513
+
514
+
515
+ # Attention ops
516
+
517
+
518
+ def fp8_gemm_nt_skip_head_mid(
519
+ a, b, d, head_splits, recipe=None, compiled_dims="nk", disable_ue8m0_cast=False
520
+ ):
521
+ a_data, a_sf = a
522
+ b_data, b_sf = b
523
+ left, mid, right = head_splits
524
+ has_recipe = recipe is not None
525
+ r0, r1, r2 = recipe if has_recipe else (0, 0, 0)
526
+ ops.fp8_gemm_nt_skip_head_mid(
527
+ a_data,
528
+ a_sf,
529
+ b_data,
530
+ b_sf,
531
+ d,
532
+ left,
533
+ mid,
534
+ right,
535
+ r0,
536
+ r1,
537
+ r2,
538
+ has_recipe,
539
+ compiled_dims,
540
+ disable_ue8m0_cast,
541
+ )
542
+
543
+
544
+ def fp8_mqa_logits(
545
+ q,
546
+ kv,
547
+ weights,
548
+ cu_seq_len_k_start,
549
+ cu_seq_len_k_end,
550
+ clean_logits=True,
551
+ max_seqlen_k=0,
552
+ ):
553
+ kv_data, kv_sf = kv
554
+ return ops.fp8_mqa_logits(
555
+ q,
556
+ kv_data,
557
+ kv_sf,
558
+ weights,
559
+ cu_seq_len_k_start,
560
+ cu_seq_len_k_end,
561
+ clean_logits,
562
+ max_seqlen_k,
563
+ )
564
+
565
+
566
+ def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
567
+ return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms)
568
+
569
+
570
+ def fp8_paged_mqa_logits(
571
+ q,
572
+ kv_cache,
573
+ weights,
574
+ context_lens,
575
+ block_table,
576
+ schedule_meta,
577
+ max_context_len,
578
+ clean_logits=False,
579
+ ):
580
+ return ops.fp8_paged_mqa_logits(
581
+ q,
582
+ kv_cache,
583
+ weights,
584
+ context_lens,
585
+ block_table,
586
+ schedule_meta,
587
+ max_context_len,
588
+ clean_logits,
589
+ )
590
+
591
+
592
+ # Einsum ops
593
+
594
+
595
+ def einsum(expr, a, b, d, c=None, use_cublaslt=False):
596
+ ops.einsum(expr, a, b, d, c, use_cublaslt)
597
+
598
+
599
+ def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)):
600
+ a_data, a_sf = a
601
+ b_data, b_sf = b
602
+ r0, r1, r2 = recipe
603
+ ops.fp8_einsum(expr, a_data, a_sf, b_data, b_sf, d, c, r0, r1, r2)
604
+
605
+
606
+ # Hyperconnection ops
607
+
608
+
609
+ def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
610
+ has_ns = num_splits is not None
611
+ ns = num_splits if has_ns else 0
612
+ ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns)
613
+
614
+
615
+ # Initialize the C++ runtime
616
+
617
+
618
+ def _find_cuda_home() -> str:
619
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
620
+ if cuda_home is None:
621
+ try:
622
+ with open(os.devnull, "w") as devnull:
623
+ nvcc = (
624
+ subprocess.check_output(["which", "nvcc"], stderr=devnull)
625
+ .decode()
626
+ .rstrip("\r\n")
627
+ )
628
+ cuda_home = os.path.dirname(os.path.dirname(nvcc))
629
+ except Exception:
630
+ cuda_home = "/usr/local/cuda"
631
+ if not os.path.exists(cuda_home):
632
+ cuda_home = None
633
+ assert cuda_home is not None, "Could not find CUDA installation"
634
+ return cuda_home
635
+
636
+
637
+ # Find the library root for JIT headers
638
+ # In development: use the repo's deep_gemm/ directory
639
+ # In installed wheel: use this package's directory
640
+ _lib_root = os.path.join(
641
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "deep_gemm"
642
+ )
643
+ if not os.path.isdir(os.path.join(_lib_root, "include")):
644
+ # Fallback: try the parent package
645
+ _lib_root = os.path.dirname(os.path.abspath(__file__))
646
+
647
+ _initialized = False
648
+
649
+ # Set DG_CUTLASS_INCLUDE for JIT kernel compilation (if not already set by user)
650
+ if "DG_CUTLASS_INCLUDE" not in os.environ:
651
+ _include = os.path.join(_lib_root, "include")
652
+ _cutlass_include_candidates = [
653
+ _include, # legacy layout: include/cutlass
654
+ os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout
655
+ ]
656
+ for _cutlass_include in _cutlass_include_candidates:
657
+ if os.path.isdir(os.path.join(_cutlass_include, "cutlass")):
658
+ os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include
659
+ break
660
+ else:
661
+ # Fall back to nvidia-cutlass pip package
662
+ try:
663
+ import nvidia.cutlass as _nc
664
+ os.environ["DG_CUTLASS_INCLUDE"] = os.path.join(
665
+ os.path.dirname(_nc.__file__), "include"
666
+ )
667
+ except ImportError:
668
+ pass
669
+
670
+ def _ensure_initialized():
671
+ global _initialized
672
+ if _initialized:
673
+ return
674
+ _initialized = True
675
+ ops.init(_lib_root, _find_cuda_home())
676
+
677
+
678
+ # Try to initialize eagerly, but don't fail if CUDA is not found
679
+ # (e.g., during build-time import checks). init() will be called
680
+ # lazily on first actual kernel use.
681
+ try:
682
+ _ensure_initialized()
683
+ except (AssertionError, RuntimeError):
684
+ pass
build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf70d854b4503e79ff06f7193781f26bb6d723a37c1646649b741c6d79c4f7d6
3
+ size 2870952
build/torch29-cxx11-cu129-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _deep_gemm_cuda_a68a39f
3
+ ops = torch.ops._deep_gemm_cuda_a68a39f
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_deep_gemm_cuda_a68a39f::{op_name}"
build/torch29-cxx11-cu129-x86_64-linux/deep_gemm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/cute_tie.cuh ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace cute {
4
+
5
+ struct ignore_t {
6
+ template <typename T>
7
+ constexpr const ignore_t& operator=(T&&) const noexcept {
8
+ return *this;
9
+ }
10
+ };
11
+
12
+ inline constexpr ignore_t ignore{};
13
+
14
+ } // namespace cute
15
+
16
+ #define CUTE_TIE_CONCAT_IMPL(A, B) A##B
17
+ #define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B)
18
+
19
+ #define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N
20
+ #define CUTE_TIE_COUNT_ARGS(...) \
21
+ CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
22
+
23
+ #define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get<I>(TUPLE)
24
+ #define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get<I>(TUPLE)
25
+
26
+ #define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1);
27
+ #define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2);
28
+ #define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3);
29
+ #define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4);
30
+ #define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5);
31
+
32
+ #define CUTE_TIE_DECL(TUPLE_EXPR, ...) \
33
+ auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
34
+ CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
35
+ CUTE_TIE_OP_DECL, \
36
+ CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
37
+ __VA_ARGS__ \
38
+ )
39
+
40
+ #define CUTE_TIE(TUPLE_EXPR, ...) \
41
+ do { \
42
+ auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
43
+ CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
44
+ CUTE_TIE_OP_ASSIGN, \
45
+ CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
46
+ __VA_ARGS__ \
47
+ ); \
48
+ } while (0)
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/epilogue_utils.cuh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <deep_gemm/common/types.hpp>
4
+ #include <deep_gemm/common/utils.cuh>
5
+
6
+ namespace deep_gemm {
7
+
8
+ struct EpilogueIdentity {
9
+ template <uint32_t STORE_BLOCK_N>
10
+ __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
11
+ return n_idx;
12
+ }
13
+ };
14
+
15
+ template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
16
+ struct EpilogueHeadSplits: EpilogueIdentity {
17
+ template <uint32_t STORE_BLOCK_N>
18
+ __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
19
+ DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0
20
+ and kRight % STORE_BLOCK_N == 0, "Invalid head splits config");
21
+ return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
22
+ }
23
+ };
24
+
25
+ #pragma clang diagnostic pop
26
+
27
+ } // namespace deep_gemm
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/reduction.cuh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda_bf16.h>
4
+ #include <cuda_fp8.h>
5
+ #include <cuda/std/cstdint>
6
+ #include <cuda/std/utility>
7
+
8
+ #include <deep_gemm/common/utils.cuh>
9
+
10
+ // Operation functors
11
+ template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } };
12
+ template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
13
+ template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
14
+ template <typename T> struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } };
15
+ template <typename T> struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } };
16
+
17
+ // Unified reduction function
18
+ template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
19
+ __forceinline__ __device__ T warp_reduce(T value, Op op) {
20
+ DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
21
+ kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
22
+ "Invalid number of lanes");
23
+ constexpr uint32_t mask = 0xffffffff;
24
+ if constexpr (kIntergroupReduce) {
25
+ if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1));
26
+ if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2));
27
+ if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4));
28
+ if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8));
29
+ if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
30
+ } else {
31
+ if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
32
+ if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
33
+ if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4));
34
+ if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2));
35
+ if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1));
36
+ }
37
+ return value;
38
+ }
39
+
40
+ // Convenience aliases
41
+ template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
42
+ __forceinline__ __device__ T warp_reduce_sum(T value) {
43
+ return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
44
+ }
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/scheduler.cuh ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <deep_gemm/common/types.hpp>
4
+ #include <deep_gemm/common/utils.cuh>
5
+
6
+ namespace deep_gemm {
7
+
8
// Which axis `Scheduler::get_global_idx` computes a global offset for
enum class IndexType {
    // The M/N (output) dimension
    MN,
    // The K (reduction) dimension; for k-grouped GEMM, offset by the K prefix sum
    K,
    // The scale-factor K dimension; for k-grouped GEMM, offset by the SF-K prefix sum
    SF_K,
};
13
+
14
+ template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool kIsMulticastOnA>
15
+ static constexpr uint32_t get_num_1d_blocks_per_group() {
16
+ // Select the best from candidates
17
+ uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
18
+ for (const auto& candidate: {8u, 16u}) {
19
+ const auto& usage = kIsMulticastOnA ?
20
+ candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
21
+ candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
22
+ if (usage < min_usage)
23
+ min_usage = usage, num_best_blocks = candidate;
24
+ }
25
+ return num_best_blocks;
26
+ }
27
+
28
+ #pragma clang diagnostic push
29
+ #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
30
// Persistent tile scheduler: each CTA repeatedly asks for its next (m, n) output
// block via `get_next_block`; block indices are strided by `kNumSMs` across CTAs
// (`iter * kNumSMs + blockIdx.x`). The different `GemmType`s are handled with
// `if constexpr` branches over the same state.
template <GemmType kGemmType,
          uint32_t BLOCK_M, uint32_t BLOCK_N,
          uint32_t kNumGroups,
          uint32_t kNumMulticast, bool kIsMulticastOnA,
          uint32_t kNumSMs,
          uint32_t SF_K_ALIGNMENT = 512u, // for k-grouped GEMM only: 128 (SM90 float SF) or 512 (SM100 UE8M0 SF)
          uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
struct Scheduler {
    // Number of completed persistent iterations (-1 before the first `get_next_block`)
    int current_iter = -1;

    // Block configs
    uint32_t num_blocks;
    uint32_t num_m_blocks;
    uint32_t num_n_blocks;

    // For SM90 multicast checks
    uint32_t num_blocks_in_group;
    bool is_peer_cta_alive = true;

    // For grouped GEMM
    int* grouped_layout;
    uint32_t current_group_idx = 0;
    // Only used for masked layout
    uint32_t current_m_cumsum = 0;
    // Only used for contiguous psum layout
    uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0;
    // Only used for k-grouped layout
    uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
    uint32_t next_group_idx, next_shape_k;

    // Only used for k-grouped GEMM: advance `group_idx` (in place) to the next
    // group with a positive K extent, loading that extent into `shape_k`
    __device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const {
        for (; group_idx < kNumGroups; ++ group_idx) {
            shape_k = __ldg(grouped_layout + group_idx);
            if (shape_k > 0)
                break;
        }
    }

    // ReSharper disable once CppPossiblyUninitializedMember
    __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k,
                                                  int* grouped_layout = nullptr) {
        num_m_blocks = ceil_div(shape_m, BLOCK_M);
        num_n_blocks = ceil_div(shape_n, BLOCK_N);
        current_shape_k = shape_k;
        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
            num_blocks = num_m_blocks * num_n_blocks;
        } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
            num_blocks = num_m_blocks * num_n_blocks;
            this->grouped_layout = grouped_layout;
        } else if constexpr (kGemmType == GemmType::MGroupedMasked) {
            this->grouped_layout = grouped_layout;
        } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
            this->grouped_layout = grouped_layout;
            // First entry of the layout is the first group's prefix-summed M
            current_psum_m = __ldg(grouped_layout);
            num_m_blocks = ceil_div(current_psum_m, BLOCK_M);
        } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
            this->grouped_layout = grouped_layout;
            // Locate the first two non-empty K groups
            get_next_k_group(current_group_idx, current_shape_k);
            next_group_idx = current_group_idx + 1;
            get_next_k_group(next_group_idx, next_shape_k);
        }
    }

    // Map a linear block index to swizzled (m, n) block coordinates; also updates
    // `num_blocks_in_group` (used by the SM90 multicast validity check)
    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
        DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size");

        // Swizzle for better L2 usages
        const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks;
        const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks;
        const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
        const auto& group_idx = block_idx / num_blocks_per_group;
        auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
        auto in_group_idx = block_idx % num_blocks_per_group;
        num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);

        // Fix unaligned TMA multicast
        // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast
        // while SM100 uses 2-CTA, which can not be dynamically disabled
        #if __CUDA_ARCH__ < 1000
        if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) {
            // `x ^ 1 == x - 1` for odd `x`: shrink the group to an even size and
            // route the leftover line into its own group of size 1
            if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
                num_blocks_in_group = num_blocks_in_group ^ 1;
            } else {
                in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
                first_block_idx += num_blocks_in_group ^ 1;
                num_blocks_in_group = 1;
            }
        }
        #endif

        // Convert to final M/N block indices
        // `kIsMulticastOnA == true` leads to groups on N
        if constexpr (kIsMulticastOnA) {
            m_block_idx = in_group_idx / num_blocks_in_group;
            n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
        } else {
            m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
            n_block_idx = in_group_idx / num_blocks_in_group;
        }
    }

    // Translate a local block index to a global index along the axis selected by
    // `kIndexType`, applying the current group's offset when `kWithGroupOffset`
    template <bool kWithGroupOffset, IndexType kIndexType = IndexType::MN>
    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
                                                       const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
        if constexpr (kGemmType == GemmType::Normal) {
            return block_idx * block_size;
        } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
            // The layout entry at the block's first row gives the group index
            // (negative entries mark padding; clamp them to group 0)
            const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
            return offset * shape_dim + block_idx * block_size;
        } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
            const auto offset = kWithGroupOffset ? current_group_idx : 0;
            return offset * shape_dim + block_idx * block_size;
        } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
            auto offset = 0;
            if constexpr (kWithGroupOffset) {
                if constexpr (kIndexType == IndexType::MN)
                    offset = current_group_idx * shape_dim;
                else if constexpr (kIndexType == IndexType::K)
                    offset = current_k_cumsum;
                else if constexpr (kIndexType == IndexType::SF_K)
                    offset = current_sf_k_cumsum;
            }
            return offset + block_idx * block_size;
        } else if constexpr (kGemmType == GemmType::Batched) {
            // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K
            const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0;
            return offset * shape_dim + block_idx * block_size;
        }
    }

    // Fetch this CTA's next (m, n) block; returns `false` when all work is done
    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
        const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x;

        if constexpr (kGemmType == GemmType::MGroupedMasked) {
            while (true) {
                // End of the task
                if (current_group_idx == kNumGroups)
                    return false;

                // Within current group
                num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + current_group_idx)), BLOCK_M);
                const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks;
                if (next_block_idx < current_m_block_cumsum * num_n_blocks)
                    break;

                // Move to check the next group
                current_group_idx ++, current_m_cumsum = current_m_block_cumsum;
            }

            get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
        } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
            while (true) {
                // Within current group
                if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks)
                    break;

                // Move to check the next group
                if (++ current_group_idx == kNumGroups)
                    return false;

                // NOTES: `num_m_blocks` varies with the increase of the group index
                last_psum_m = align(current_psum_m, 128u);
                current_psum_m = __ldg(grouped_layout + current_group_idx);
                current_m_block_cumsum += num_m_blocks;
                num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M);
            }

            get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx);

            // NOTES: `last_psum_m` is aligned with 128
            m_block_idx += last_psum_m / BLOCK_M;
            DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M");
        } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
            while (true) {
                // End of the task
                if (current_group_idx == kNumGroups)
                    return false;

                // Within current group
                if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
                    break;

                // Move to check the next group, accumulating the K and SF-K prefix sums
                current_k_cumsum += current_shape_k;
                current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT);
                current_num_valid_groups ++;

                current_group_idx = next_group_idx ++;
                current_shape_k = next_shape_k;
                get_next_k_group(next_group_idx, next_shape_k);
            }

            get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx);
        } else if constexpr (kGemmType == GemmType::Batched) {
            if (next_block_idx >= num_blocks * kNumGroups)
                return false;

            current_group_idx = next_block_idx / num_blocks;
            const auto& block_idx = next_block_idx - current_group_idx * num_blocks;
            if constexpr (kIsMulticastOnA) {
                m_block_idx = block_idx / num_n_blocks;
                n_block_idx = block_idx % num_n_blocks;
            } else {
                m_block_idx = block_idx % num_m_blocks;
                n_block_idx = block_idx / num_m_blocks;
            }
        } else {
            if (next_block_idx >= num_blocks)
                return false;

            // For SM90 only
            // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
            is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or    // Always aligned on N (constant bypass)
                                num_m_blocks % kNumMulticast == 0 or    // Always aligned on M (constant bypass)
                                (next_block_idx ^ 1) < num_blocks;      // Peer CTA in bound
            get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);
        }
        return true;
    }

    // For SM90 only
    __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
        if (num_blocks_in_group == 1)
            return false;
        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or
                      kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) {
            return true;
        } else {
            DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");
            if constexpr (kIsMulticastOnA) {
                return true;
            } else {
                // Multicasting on B requires the two paired row blocks to belong
                // to the same group
                const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
                const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
                return group_idx == peer_group_idx;
            }
        }
    }

    // For SM90 only
    // ReSharper disable once CppNotAllPathsReturnValue
    __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
            return true;
        } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
            // Negative layout entries mark padded (invalid) rows
            return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
        } else if constexpr (kGemmType == GemmType::MGroupedMasked) {
            return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx);
        } else {
            // Unreachable
            DG_TRAP_ONLY_DEVICE_ASSERT(false);
        }
    }
};
285
+
286
+ #pragma clang diagnostic pop
287
+
288
+ } // namespace deep_gemm
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/sm100_utils.cuh ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/atom/mma_traits_sm100.hpp>
4
+ #include <cute/arch/mma_sm100_umma.hpp>
5
+ #include <cute/arch/tmem_allocator_sm100.hpp>
6
+ #include <cutlass/arch/barrier.h>
7
+
8
+ #include <deep_gemm/common/utils.cuh>
9
+ #include <deep_gemm/common/tma_utils.cuh>
10
+
11
+ namespace deep_gemm::sm100 {
12
+
13
// Build an SM100 UMMA shared-memory descriptor for `smem_ptr`.
// NOTES: the address and the byte offsets are encoded in 16-byte units (`>> 4`)
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr,
                                          uint32_t stride_byte_offset, uint32_t leading_byte_offset) {
    cute::UMMA::SmemDescriptor desc;

    // Set the version for SM100
    desc.version_ = 1;

    // Legacy mode
    desc.lbo_mode_ = 0;

    // Layout (swizzling) type
    desc.layout_type_ = static_cast<uint8_t>(layout);

    // Start address, in 16-byte units
    const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
    desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);

    // Base offset
    desc.base_offset_ = 0;

    // SBO and LBO, also in 16-byte units
    desc.stride_byte_offset_ = stride_byte_offset >> 4;
    desc.leading_byte_offset_ = leading_byte_offset >> 4;

    return desc;
}
40
+
41
// Build the shared-memory descriptor used for scale-factor (SF) tiles
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) {
    // NOTES: the UTCCP layout is K-major by default
    // Atom size: 8 x 128 bits
    // {SBO, LBO} means the byte stride between atoms on {MN, K}
    // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero
    return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0);
}
49
+
50
+ __device__ __forceinline__
51
+ void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) {
52
+ const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
53
+ desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
54
+ }
55
+
56
+ __device__ __forceinline__
57
+ static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) {
58
+ return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16;
59
+ }
60
+
61
// ReSharper disable once CppNotAllPathsReturnValue
// Map a (major-ness, swizzle-mode) pair to the UMMA layout type enum.
// `kUseBase32` selects the 32-byte-base variant of 128B swizzling, which is
// mandatory for MN-major `float` operands (see the static assert below).
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, bool kUseBase32, typename dtype_t>
constexpr static cute::UMMA::LayoutType to_umma_layout_type() {
    DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
                     kSwizzleMode == 32 or kSwizzleMode == 64 or
                     kSwizzleMode == 128, "Invalid swizzling mode");
    // A special case: MN-major float (or an explicit request) uses the 32B base
    if constexpr ((cute::is_same_v<dtype_t, float> and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) {
        DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base");
        return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B;
    }

    // Normal cases (16 means no swizzling, only interleaving)
    if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE;
    if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE;
    if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B;
    if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B;
    if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B;
}
80
+
81
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
82
+ __device__ __forceinline__
83
+ constexpr uint32_t get_umma_desc_stride_k() {
84
+ return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
85
+ }
86
+
87
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
88
+ __device__ __forceinline__
89
+ uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) {
90
+ return base + (((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
91
+ }
92
+
93
// Build the UMMA descriptor for the (mn_idx, k_idx)-th tile of a shared-memory
// buffer laid out as BLOCK_MN x BLOCK_K with the given major-ness and swizzling
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, bool kUseBase32 = false, typename dtype_t>
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
    const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
    const auto& layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
    // Number of non-contiguous atom rows per 128-byte span
    const auto& num_non_contiguous = 128 / get_atom_base(layout_type);
    if constexpr (kMajorMode == cute::UMMA::Major::K) {
        // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)`
        // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis
        DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");

        // Atom size: 8 x `kSwizzleMode` (in bytes, on K)
        // {SBO, LBO} means the byte stride between atoms on {MN, K}
        // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
        const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
        const uint32_t leading_byte_offset = 0;
        return make_smem_desc(layout_type,
                              base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
                              stride_byte_offset, leading_byte_offset);
    } else {
        constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();

        // Must have no in-atom MN-idx
        // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
        DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
        DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");

        // Atom size: `kSwizzleMode` (in bytes, on MN) x 8
        // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
        // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
        // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
        uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
        uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
        if constexpr (kSwizzleMode == 16)
            swap(stride_byte_offset, leading_byte_offset);
        return make_smem_desc(layout_type,
                              base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
                              stride_byte_offset, leading_byte_offset);
    }
}
133
+
134
+ __device__ __forceinline__
135
+ uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) {
136
+ desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id;
137
+ return static_cast<uint64_t>(static_cast<uint32_t>(desc)) << 32;
138
+ }
139
+
140
+ template <uint32_t kNumCols>
141
+ __device__ constexpr uint32_t get_num_aligned_tmem_cols() {
142
+ DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns");
143
+ if (kNumCols <= 32) return 32;
144
+ if (kNumCols <= 64) return 64;
145
+ if (kNumCols <= 128) return 128;
146
+ if (kNumCols <= 256) return 256;
147
+ return 512;
148
+ }
149
+
150
// Fence ordering prior tcgen05 (tensor-memory) operations before a following
// thread synchronization
__device__ __forceinline__ void tcgen05_before_thread_sync() {
    asm volatile("tcgen05.fence::before_thread_sync;");
}

// Fence ordering subsequent tcgen05 operations after a preceding thread
// synchronization
__device__ __forceinline__ void tcgen05_after_thread_sync() {
    asm volatile("tcgen05.fence::after_thread_sync;");
}
157
+
158
// Issue a TMA `tile::gather4` 2D copy: gather the four rows in `row_idxs`
// (starting at column `col_idx`) from the tensor map at `desc_ptr` into shared
// memory, completing on `mbarrier` (transaction-bytes semantics) with an L2
// `cache_hint` policy
__device__ __forceinline__
void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) {
    uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier);
    asm volatile(
        "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
        :
        : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
          "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
          "r"(mbarrier_addr), "l"(cache_hint)
        : "memory"
    );
}
171
+
172
// UMMA versions with relaxed assertions
// 1-CTA FP16/BF16 UMMA (`tcgen05.mma.kind::f16`) with smem descriptors for both
// operands; the predicate built from `scale_c != 0` is the accumulate-enable
// operand, and `desc` carries the instruction descriptor in its high 32 bits
struct SM100_MMA_F16BF16_SS {
    __device__ static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scale_c,
        uint64_t const& desc) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
            "}\n"
            :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
    }
};
189
+
190
// 2-CTA variant of `SM100_MMA_F16BF16_SS` (`cta_group::2`)
struct SM100_MMA_F16BF16_2x1SM_SS {
    __device__ static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scale_c,
        uint64_t const& desc) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t"
            "}\n"
            :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
    }
};
206
+
207
// 1-CTA block-scaled MX FP8/FP6/FP4 UMMA (`kind::mxf8f6f4.block_scale`);
// `tmem_sfa`/`tmem_sfb` are the tensor-memory addresses of the A/B scale factors
struct SM100_MMA_MXF8F6F4_SS {
    __device__ static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scale_c,
        uint64_t const& desc,
        uint32_t const& tmem_sfa,
        uint32_t const& tmem_sfb) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
              "r"(tmem_sfa), "r"(tmem_sfb));
    }
};
227
+
228
// 2-CTA variant of `SM100_MMA_MXF8F6F4_SS` (`cta_group::2`)
struct SM100_MMA_MXF8F6F4_2x1SM_SS {
    __device__ static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scale_c,
        uint64_t const& desc,
        uint32_t const& tmem_sfa,
        uint32_t const& tmem_sfb) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
            "}\n"
            :
            : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
              "r"(tmem_sfa), "r"(tmem_sfb));
    }
};
248
+
249
// Weight-stationary (`tcgen05.mma.ws`) 1-CTA FP16/BF16 UMMA variant
struct SM100_MMA_F16BF16_WS_SS {
    __device__ static void
    fma(uint64_t const& desc_a,
        uint64_t const& desc_b,
        uint32_t const& tmem_c,
        uint32_t const& scale_c,
        uint64_t const& desc) {
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
            "}\n"
            :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
    }
};
265
+
266
+ } // namespace `deep_gemm::sm100`
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/sm90_utils.cuh ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/arch/cluster_sm90.hpp>
4
+ #include <cute/arch/mma_sm90_desc.hpp>
5
+ #include <cute/arch/mma_sm90_gmma.hpp>
6
+ #include <cute/arch/mma_sm90_gmma_ext.hpp>
7
+ #include <cute/arch/mma_sm100_desc.hpp>
8
+
9
+ #include <deep_gemm/common/utils.cuh>
10
+ #include <deep_gemm/common/sm100_utils.cuh>
11
+ #include <deep_gemm/common/tma_utils.cuh>
12
+
13
+ namespace deep_gemm::sm90 {
14
+
15
// Wraps an SM90 FP8 GMMA atom of shape 64 x N_ x 32, expanding the accumulator
// array `d` into the atom's register operands via an index sequence
template <int N_, typename MMA>
struct FP8MMA {

    template <size_t ...Idx>
    __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
        using namespace cute::SM90::GMMA;
        // `scale_d` selects accumulate (One) vs overwrite (Zero)
        MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
    }

    __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
    }

    static constexpr int M = 64;
    static constexpr int N = N_;
    static constexpr int K = 32;
    // Accumulators per thread: M * N spread over the 128 threads of a warpgroup
    static constexpr int kNumAccum = M * N / 128;
};
33
+
34
// Compile-time table selecting the SM90 FP8 (E4M3 x E4M3 -> F32) 64xNx32 GMMA
// atom for a given N (multiples of 8, up to 256); `type` is the `FP8MMA` wrapper
template <int N>
struct FP8MMASelector {

    static constexpr auto select_mma() {
        using namespace cute::SM90::GMMA;
        if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN();
        if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN();
    }

    static constexpr auto select_type() {
        return FP8MMA<N, decltype(select_mma())>();
    }

    using type = decltype(select_type());
};
79
+
80
// Wraps an SM90 BF16 GMMA atom of shape 64 x N_ x 16; same expansion scheme
// as `FP8MMA`, with K = 16 instead of 32
template <int N_, typename MMA>
struct BF16MMA {

    template <size_t ...Idx>
    __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
        using namespace cute::SM90::GMMA;
        // `scale_d` selects accumulate (One) vs overwrite (Zero)
        MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
    }

    __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
    }

    static constexpr int M = 64;
    static constexpr int N = N_;
    static constexpr int K = 16;
    // Accumulators per thread: M * N spread over the 128 threads of a warpgroup
    static constexpr int kNumAccum = M * N / 128;
};
98
+
99
+ template <cute::UMMA::Major kMajor>
100
+ constexpr cute::SM90::GMMA::Major to_sm90_major() {
101
+ DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness");
102
+ return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN;
103
+ }
104
+
105
// Compile-time table selecting the SM90 BF16 (F32 accumulate) 64xNx16 GMMA atom
// for a given N (multiples of 8, up to 256) and the A/B major-ness
template <int N,
          cute::UMMA::Major kMajorA = cute::UMMA::Major::K,
          cute::UMMA::Major kMajorB = cute::UMMA::Major::K>
struct BF16MMASelector {

    static constexpr auto select_mma() {
        using namespace cute::SM90::GMMA;
        constexpr auto kGMMAMajorA = to_sm90_major<kMajorA>();
        constexpr auto kGMMAMajorB = to_sm90_major<kMajorB>();
        if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
        if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
    }

    static constexpr auto select_type() {
        return BF16MMA<N, decltype(select_mma())>();
    }

    using type = decltype(select_type());
};
152
+ using type = decltype(select_type());
153
+ };
154
+
155
+ template <int N_, typename MMA>
156
+ struct TF32MMARS {
157
+
158
+ template <size_t ...Idx>
159
+ __forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
160
+ using namespace cute::SM90::GMMA;
161
+ MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
162
+ }
163
+
164
+ __forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
165
+ call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
166
+ }
167
+
168
+ static constexpr int M = 64;
169
+ static constexpr int N = N_;
170
+ static constexpr int K = 8;
171
+ static constexpr int kNumAccum = M * N / 128;
172
+ };
173
+
174
+ template <int N, bool kUseRS = true>
175
+ struct TF32MMASelector {
176
+
177
+ static constexpr auto select_mma() {
178
+ using namespace cute::SM90::GMMA;
179
+ if constexpr (kUseRS) {
180
+ if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
181
+ if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
182
+ if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
183
+ if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
184
+ if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
185
+ if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
186
+ DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
187
+ }
188
+ }
189
+
190
+ static constexpr auto select_type() {
191
+ if constexpr (kUseRS) {
192
+ return TF32MMARS<N, decltype(select_mma())>();
193
+ } else {
194
+ DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
195
+ }
196
+ }
197
+
198
+ using type = decltype(select_type());
199
+ };
200
+
201
+ template <typename dtype_t>
202
+ struct SM90_U32x2_STSM_N {
203
+ __device__ __forceinline__ static void
204
+ copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
205
+ const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
206
+ asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
207
+ :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1]));
208
+ }
209
+ };
210
+
211
+ struct SM90_U32x2_LDSM_N {
212
+ __device__ __forceinline__ static void
213
+ copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) {
214
+ asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
215
+ : "=r"(dst_0), "=r"(dst_1)
216
+ : "l"(__cvta_generic_to_shared(smem_src)));
217
+ }
218
+ };
219
+
220
+ struct SM90_U32x4_LDSM_N {
221
+ __device__ __forceinline__ static void
222
+ copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) {
223
+ asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
224
+ : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3)
225
+ : "l"(__cvta_generic_to_shared(smem_src)));
226
+ }
227
+ };
228
+
229
+ __forceinline__ __device__ void warpgroup_arrive() {
230
+ asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
231
+ }
232
+
233
+ __forceinline__ __device__ void warpgroup_commit_batch() {
234
+ asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
235
+ }
236
+
237
+ __forceinline__ __device__ void warpgroup_fence_operand(float& reg) {
238
+ asm volatile("" : "+f"(reg) :: "memory");
239
+ }
240
+
241
+ template <int N>
242
+ __forceinline__ __device__ void warpgroup_wait() {
243
+ DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
244
+ asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
245
+ }
246
+
247
+ template <class PointerType>
248
+ __device__ cute::GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type,
249
+ const int& leading_byte_offset = 0,
250
+ const int& stride_byte_offset = 1024) {
251
+ // NOTES: the default LBO and SBO are for K-major types
252
+ cute::GmmaDescriptor desc;
253
+ const auto& uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
254
+ desc.bitfield.start_address_ = uint_ptr >> 4;
255
+ desc.bitfield.layout_type_ = layout_type;
256
+ desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
257
+ desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
258
+ desc.bitfield.base_offset_ = 0;
259
+ return desc;
260
+ }
261
+
262
+ template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
263
+ constexpr uint32_t get_inner_block_atom_size() {
264
+ return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t);
265
+ }
266
+
267
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
268
+ __device__ __forceinline__
269
+ constexpr uint32_t get_gmma_desc_stride_k() {
270
+ return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
271
+ }
272
+
273
+ // ReSharper disable once CppNotAllPathsReturnValue
274
+ template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, typename dtype_t>
275
+ constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() {
276
+ DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
277
+ kSwizzleMode == 32 or kSwizzleMode == 64 or
278
+ kSwizzleMode == 128, "Invalid swizzling mode");
279
+
280
+ // Normal cases
281
+ if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
282
+ if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
283
+ if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32;
284
+ if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64;
285
+ if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128;
286
+ }
287
+
288
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
289
+ __device__ __forceinline__
290
+ uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) {
291
+ return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
292
+ }
293
+
294
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
295
+ __device__ __forceinline__
296
+ cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
297
+ const uint32_t stride_k = get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
298
+ const auto& layout_type = to_gmma_layout_type<kMajorMode, kSwizzleMode, dtype_t>();
299
+ constexpr uint32_t num_non_contiguous = 128 / 16;
300
+ if constexpr (kMajorMode == cute::UMMA::Major::K) {
301
+ // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
302
+ DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
303
+
304
+ // Atom size: 8 x `kSwizzleMode` (in bytes, on K)
305
+ // {SBO, LBO} means the byte stride between atoms on {MN, K}
306
+ // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
307
+ const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
308
+ const uint32_t leading_byte_offset = 0;
309
+ return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
310
+ leading_byte_offset, stride_byte_offset);
311
+ } else {
312
+ constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
313
+
314
+ // Must have no in-atom MN-idx
315
+ // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
316
+ DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
317
+ DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
318
+
319
+ // Atom size: `kSwizzleMode` (in bytes, on MN) x 8
320
+ // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
321
+ // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
322
+ // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
323
+ uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
324
+ uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
325
+ if constexpr (kSwizzleMode == 16)
326
+ swap(stride_byte_offset, leading_byte_offset);
327
+ return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
328
+ leading_byte_offset, stride_byte_offset);
329
+ }
330
+ }
331
+
332
+ } // namespace `deep_gemm::sm90`
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/tma_utils.cuh ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/arch/copy_sm90_tma.hpp>
4
+ #include <cute/arch/copy_sm100_tma.hpp>
5
+ #include <cutlass/arch/barrier.h>
6
+
7
+ namespace deep_gemm {
8
+
9
+ template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
10
+ constexpr uint32_t get_inner_block_atom_size() {
11
+ return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t);
12
+ }
13
+
14
+ template <uint32_t BLOCK_INNER, uint32_t BLOCK_OUTER,
15
+ uint32_t kSwizzleMode,
16
+ typename dtype_t, bool kIs3DTMA = false>
17
+ __device__ __forceinline__ void
18
+ tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
19
+ dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx,
20
+ const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) {
21
+ DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
22
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
23
+ constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<BLOCK_INNER, kSwizzleMode, dtype_t>();
24
+
25
+ if constexpr (not kIs3DTMA) {
26
+ if (num_tma_multicast == 1) {
27
+ #pragma unroll
28
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
29
+ cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
30
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
31
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
32
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
33
+ }
34
+ } else {
35
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
36
+ // 2-CTA function will send signals to the leader CTA only
37
+ #pragma unroll
38
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
39
+ cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
40
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
41
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
42
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
43
+ }
44
+ #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
45
+ if (cute::block_rank_in_cluster() == 0) {
46
+ #pragma unroll
47
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
48
+ cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
49
+ (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
50
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
51
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
52
+ }
53
+ }
54
+ #endif
55
+ }
56
+ } else {
57
+ if (num_tma_multicast == 1) {
58
+ #pragma unroll
59
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
60
+ cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
61
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
62
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
63
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
64
+ }
65
+ } else {
66
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
67
+ // 2-CTA function will send signals to the leader CTA only
68
+ #pragma unroll
69
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
70
+ cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
71
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
72
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
73
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
74
+ }
75
+ #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
76
+ if (cute::block_rank_in_cluster() == 0) {
77
+ #pragma unroll
78
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
79
+ cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
80
+ (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
81
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
82
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
83
+ }
84
+ }
85
+ #endif
86
+ }
87
+ }
88
+ }
89
+
90
+ // Tensormap related
91
+ __device__ __forceinline__ void tensor_map_release_cta() {
92
+ asm volatile ("fence.proxy.tensormap::generic.release.cta;");
93
+ }
94
+
95
+ __device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) {
96
+ auto gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
97
+ asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory");
98
+ }
99
+
100
+ __device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) {
101
+ auto smem_int_desc = static_cast<uint32_t>(__cvta_generic_to_shared(smem_desc));
102
+ const auto new_int64_addr = reinterpret_cast<uint64_t>(new_addr);
103
+ asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr));
104
+ }
105
+
106
+ __device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
107
+ auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
108
+ asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
109
+ #if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3)))
110
+ asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
111
+ #else
112
+ DG_STATIC_ASSERT(false, "Invalid CUDA version");
113
+ #endif
114
+ }
115
+
116
+ } // namespace `deep_gemm`
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/types.hpp ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace deep_gemm {
4
+
5
+ enum class MmaKind {
6
+ BF16 = 0,
7
+ MXFP8FP4 = 1,
8
+ };
9
+
10
+ constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) {
11
+ switch (mma_kind) {
12
+ case MmaKind::BF16: return 2;
13
+ case MmaKind::MXFP8FP4: return 1;
14
+ default: return 0;
15
+ }
16
+ }
17
+
18
+ enum class GemmType {
19
+ Normal = 0,
20
+ MGroupedContiguous = 1,
21
+ MGroupedMasked = 2,
22
+ KGroupedContiguous = 3,
23
+ Batched = 4,
24
+ MGroupedContiguousWithPsumLayout = 5,
25
+ };
26
+
27
+ constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) {
28
+ switch (gemm_type) {
29
+ case GemmType::MGroupedContiguous: return true;
30
+ case GemmType::MGroupedContiguousWithPsumLayout: return true;
31
+ default: return false;
32
+ }
33
+ }
34
+
35
+ enum class KernelType {
36
+ Kernel1D1D = 0,
37
+ Kernel1D2D = 1,
38
+ KernelNoSF = 2
39
+ };
40
+
41
+ } // namespace deep_gemm
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/common/utils.cuh ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda_bf16.h>
4
+ #include <cuda_fp8.h>
5
+ #include <cuda/std/cstdint>
6
+ #include <cuda/std/utility>
7
+ #include <cute/container/tuple.hpp>
8
+
9
+ #include "cute_tie.cuh"
10
+
11
+ #ifdef __CLION_IDE__
12
+
13
+ __host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
14
+ asm volatile("trap;");
15
+ }
16
+
17
+ #define printf host_device_printf
18
+ #endif
19
+
20
+ #ifndef DG_DEVICE_ASSERT
21
+ #define DG_DEVICE_ASSERT(cond) \
22
+ do { \
23
+ if (not (cond)) { \
24
+ printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
25
+ asm("trap;"); \
26
+ } \
27
+ } while (0)
28
+ #endif
29
+
30
+ #ifndef DG_TRAP_ONLY_DEVICE_ASSERT
31
+ #define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
32
+ do { \
33
+ if (not (cond)) \
34
+ asm("trap;"); \
35
+ } while (0)
36
+ #endif
37
+
38
+ #ifndef DG_STATIC_ASSERT
39
+ #define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
40
+ #endif
41
+
42
+ namespace deep_gemm {
43
+
44
+ template <typename FuncT>
45
+ struct PatternVisitor {
46
+ FuncT func;
47
+
48
+ __device__ __host__
49
+ explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {}
50
+
51
+ __device__ __host__
52
+ auto operator [](const uint32_t& i) {
53
+ return func(i);
54
+ }
55
+ };
56
+
57
+ template <typename T>
58
+ __device__ __host__ T ceil_div(T a, T b) {
59
+ return (a + b - 1) / b;
60
+ }
61
+
62
+ template <typename T>
63
+ __device__ __host__ constexpr T constexpr_ceil_div(T a, T b) {
64
+ return (a + b - 1) / b;
65
+ }
66
+
67
+ template <typename T>
68
+ __device__ __host__ T align(T a, T b) {
69
+ return ceil_div(a, b) * b;
70
+ }
71
+
72
+ template <typename T>
73
+ __device__ __host__ constexpr T constexpr_align(T a, T b) {
74
+ return constexpr_ceil_div(a, b) * b;
75
+ }
76
+
77
+ template <typename T>
78
+ __device__ __host__ constexpr T constexpr_gcd(T a, T b) {
79
+ return b == 0 ? a : constexpr_gcd(b, a % b);
80
+ }
81
+
82
+ template<typename T>
83
+ __forceinline__ __device__ void swap(T& a, T& b) {
84
+ T temp = a;
85
+ a = b;
86
+ b = temp;
87
+ }
88
+
89
+ __forceinline__ __device__ uint32_t get_sm_idx() {
90
+ uint32_t sm_idx;
91
+ asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx));
92
+ return sm_idx;
93
+ }
94
+
95
+ __forceinline__ __device__ uint32_t get_lane_idx() {
96
+ uint32_t lane_id;
97
+ asm ("mov.u32 %0, %laneid;" : "=r"(lane_id));
98
+ return lane_id;
99
+ }
100
+
101
+ __device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
102
+ uint32_t ret;
103
+ asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr)));
104
+ return ret;
105
+ }
106
+
107
+ __device__ __forceinline__ float2 ld_shared(const float2* ptr) {
108
+ float2 ret;
109
+ asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr)));
110
+ return ret;
111
+ }
112
+
113
+ __device__ __forceinline__ float4 ld_shared(const float4* ptr) {
114
+ float4 ret;
115
+ asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
116
+ return ret;
117
+ }
118
+
119
+ __device__ __forceinline__ uint4 ld_shared(const uint4* ptr) {
120
+ uint4 ret;
121
+ asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
122
+ return ret;
123
+ }
124
+
125
+ __device__ __forceinline__ float ld_shared(const float* ptr) {
126
+ float ret;
127
+ asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr)));
128
+ return ret;
129
+ }
130
+
131
+ __device__ __forceinline__ void st_shared(const float* ptr, float val) {
132
+ asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val));
133
+ }
134
+
135
+ __device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
136
+ asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y));
137
+ }
138
+
139
+ __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
140
+ asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val));
141
+ }
142
+
143
+ __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) {
144
+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y));
145
+ }
146
+
147
+ __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
148
+ asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w));
149
+ }
150
+
151
+ __device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) {
152
+ asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
153
+ }
154
+
155
+ template <typename old_t>
156
+ __device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
157
+ auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
158
+ return *reinterpret_cast<int*>(&bf16x2);
159
+ }
160
+
161
+ __device__ __forceinline__ void prefetch_l1(void *ptr) {
162
+ asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
163
+ }
164
+
165
+ template <uint32_t kNumBytes>
166
+ struct Vectorized {
167
+ static auto zeros() {
168
+ // TODO: add `ulonglong4` for SM100 once `__ldg` support this
169
+ if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) {
170
+ return make_uint4(0, 0, 0, 0);
171
+ } else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) {
172
+ return make_uint2(0, 0);
173
+ } else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) {
174
+ return 0;
175
+ } else {
176
+ DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization");
177
+ }
178
+ }
179
+
180
+ using vec_t = decltype(zeros());
181
+ };
182
+
183
+ } // namespace `deep_gemm`
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #pragma clang diagnostic push
3
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
4
+
5
+ #include <cutlass/arch/barrier.h>
6
+
7
+ #include <deep_gemm/common/scheduler.cuh>
8
+ #include <deep_gemm/common/utils.cuh>
9
+ #include <deep_gemm/common/sm100_utils.cuh>
10
+
11
+ namespace deep_gemm {
12
+
13
+ using namespace deep_gemm::sm100;
14
+
15
+ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
16
+ uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
17
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
18
+ uint32_t kNumGroups,
19
+ uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
20
+ uint32_t kNumStages_,
21
+ uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
22
+ uint32_t kNumMulticast, bool kIsMulticastOnA,
23
+ uint32_t kNumSMs,
24
+ GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
25
+ uint64_t kTensorCoreUtilControl>
26
+ __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
27
+ sm100_bf16_gemm_impl(int* grouped_layout,
28
+ uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
29
+ const __grid_constant__ cute::TmaDescriptor tensor_map_a,
30
+ const __grid_constant__ cute::TmaDescriptor tensor_map_b,
31
+ const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
32
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
33
+ // Enlarge `BLOCK_K` for some cases
34
+ // NOTES: this is for reducing the `umma_arrive()` overhead
35
+ constexpr bool kDoMergeStages =
36
+ kNumStages_ >= 8 and kGemmType == GemmType::Normal and
37
+ kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K;
38
+ // Ensure there are at least `kNumMinStages` stages after merge
39
+ constexpr uint32_t kNumMinStages = 8;
40
+ constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1;
41
+ constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge;
42
+ constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
43
+
44
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
45
+ using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
46
+
47
+ // GEMM with accumulation must have FP32 output
48
+ if constexpr (kWithAccumulation)
49
+ DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
50
+
51
+ // Configs
52
+ constexpr uint32_t LAYOUT_AD_M = 128;
53
+ constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
54
+ constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
55
+ constexpr uint32_t kNumTMAStoreStages = 2;
56
+ DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K");
57
+ DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
58
+ DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode");
59
+
60
+ // Overwrite shape constants if the compiler gives
61
+ shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
62
+ shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
63
+ shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
64
+
65
+ // Utils
66
+ bool is_leader_cta = cute::block_rank_in_cluster() == 0;
67
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
68
+ const auto lane_idx = get_lane_idx();
69
+
70
+ // Align to 1024 bytes for swizzle-128B
71
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
72
+
73
+ // 2-CTA MMA
74
+ constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
75
+ constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
76
+ constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
77
+ constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
78
+ constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M;
79
+ DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
80
+ DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D");
81
+ DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
82
+ DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
83
+
84
+ // Share memory sizes
85
+ constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
86
+ constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
87
+ constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
88
+ constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
89
+ DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
90
+ "Shared memory of A/B must be aligned to 1024 bytes");
91
+ DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
92
+
93
+ // NOTES: Make sure we have enough shared memory for UMMA padding
94
+ static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
95
+ DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
96
+
97
+ // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
98
+ // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
99
+ constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2;
100
+
101
+ // Real tensor memory size and offsets
102
+ constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
103
+ constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
104
+
105
+ // Prefetch TMA descriptors at the very beginning
106
+ if (warp_idx == 0 and cute::elect_one_sync()) {
107
+ cute::prefetch_tma_descriptor(&tensor_map_a);
108
+ cute::prefetch_tma_descriptor(&tensor_map_b);
109
+ cute::prefetch_tma_descriptor(&tensor_map_cd);
110
+ }
111
+
112
+ // D/A/B shared memory
113
+ auto smem_cd = PatternVisitor([&](const uint32_t& i) {
114
+ return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
115
+ });
116
+ auto smem_a = PatternVisitor([&](const uint32_t& i) {
117
+ return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
118
+ });
119
+ auto smem_b = PatternVisitor([&](const uint32_t& i) {
120
+ return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
121
+ });
122
+
123
+ // Fill barriers
124
+ auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
125
+ auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
126
+ auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
127
+ auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
128
+ auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
129
+ auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
130
+
131
+ // Fill the tensor memory pointer
132
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
133
+ DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
134
+
135
+ // Initialize barriers
136
+ if (warp_idx == 1 and cute::elect_one_sync()) {
137
+ #pragma unroll
138
+ for (uint32_t i = 0; i < kNumStages; ++ i) {
139
+ // Arrive only at the leader CTA
140
+ full_barriers[i]->init(kNumMulticast);
141
+ // Arrive at all CTAs
142
+ empty_barriers[i]->init(1);
143
+ }
144
+ #pragma unroll
145
+ for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
146
+ // Arrive at all CTAs
147
+ tmem_full_barriers[i]->init(1);
148
+ // Arrive only at the leader CTA
149
+ tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
150
+ }
151
+ if constexpr (kTensorCoreUtilControl < 100)
152
+ tensor_core_full_barrier->init(1);
153
+
154
+ // Make initialized barrier visible in async proxy
155
+ cutlass::arch::fence_barrier_init();
156
+ } else if (warp_idx == 2) {
157
+ // Allocate tensor memory
158
+ Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
159
+ }
160
+ kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
161
+
162
+ // Block scheduler
163
+ uint32_t m_block_idx, n_block_idx;
164
+ auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
165
+
166
+ // Pipeline and TMA phases
167
+ uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
168
+ auto advance_pipeline = [&](uint32_t& k_block_idx) {
169
+ ++ k_block_idx;
170
+
171
+ // Flip phases only if reach the next first stage
172
+ stage_idx = (stage_idx + 1) % kNumStages;
173
+ phase ^= stage_idx == 0;
174
+ };
175
+
176
+ // Dispatch warps into different roles
177
+ if (warp_idx == 0 and cute::elect_one_sync()) {
178
+ // TMA load warp
179
+ // Persistently schedule over blocks
180
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
181
+ const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
182
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
183
+ // Wait consumer release
184
+ empty_barriers[stage_idx]->wait(phase ^ 1);
185
+
186
+ // Compute offsets
187
+ // NOTES: the group is always concatenated with the outer dimension
188
+ uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> (
189
+ shape_m, BLOCK_M, m_block_idx);
190
+ uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> (
191
+ shape_n, BLOCK_N, n_block_idx, m_block_idx);
192
+
193
+ // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
194
+ // And for all m-grouped GEMMs, A must be K-majored
195
+ DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
196
+ kMajorA == cute::UMMA::Major::K, "Invalid major");
197
+ uint32_t k_idx = k_block_idx * BLOCK_K;
198
+ uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
199
+ shape_k, BLOCK_K, k_block_idx, m_block_idx);
200
+ uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
201
+ shape_k, BLOCK_K, k_block_idx, m_block_idx);
202
+
203
+ // Add 2 CTA offsets
204
+ if constexpr (kNumMulticast > 1) {
205
+ m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
206
+ n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
207
+ }
208
+
209
+ // Issue TMAs
210
+ constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
211
+ const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
212
+ if constexpr (kMajorA == cute::UMMA::Major::K)
213
+ tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
214
+ &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx);
215
+ if constexpr (kMajorA == cute::UMMA::Major::MN)
216
+ tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
217
+ &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx);
218
+ if constexpr (kMajorB == cute::UMMA::Major::K)
219
+ tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
220
+ &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx);
221
+ if constexpr (kMajorB == cute::UMMA::Major::MN)
222
+ tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
223
+ &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx);
224
+
225
+ // Arrive at full barriers
226
+ constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
227
+ if (is_leader_cta) {
228
+ full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
229
+ } else {
230
+ full_barriers[stage_idx]->arrive(0u);
231
+ }
232
+ }
233
+ }
234
+ } else if (warp_idx == 1 and is_leader_cta) {
235
+ // MMA issue warp
236
+ // NOTES: only the leader CTA will do this
237
+ // Make instruction descriptor
238
+ // TODO: refactor `UMMA_M` calculation
239
+ constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
240
+ constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
241
+ constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
242
+ auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>();
243
+
244
+ DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
245
+ // Merged stages only happens in NT normal GEMM cases
246
+ constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
247
+ auto a_desc = make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], 0, 0);
248
+ auto b_desc = make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
249
+ uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
250
+ uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
251
+
252
+ // Checks for MMA instructions
253
+ // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
254
+ DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
255
+ (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
256
+ (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
257
+ "Invalid MMA instruction shape");
258
+
259
+ // Persistently schedule over blocks
260
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
261
+ // Wait tensor memory empty barrier arrival
262
+ auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
263
+ auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
264
+ tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
265
+ tcgen05_after_thread_sync();
266
+
267
+ // UMMA and empty barrier arrival alias
268
+ auto umma_arrive = [](const uint64_t* barrier) {
269
+ if constexpr (kNumMulticast == 1) {
270
+ cutlass::arch::umma_arrive(barrier);
271
+ } else {
272
+ constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
273
+ cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
274
+ }
275
+ };
276
+ auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
277
+ umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
278
+
279
+ // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
280
+ if (do_tmem_full_arrive)
281
+ umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
282
+ };
283
+
284
+ // Launch MMAs
285
+ const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
286
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
287
+ // Wait TMA arrival
288
+ full_barriers[stage_idx]->wait(phase);
289
+ tcgen05_after_thread_sync();
290
+
291
+ // Issue UMMA in the leader CTA
292
+ using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_F16BF16_SS, SM100_MMA_F16BF16_2x1SM_SS>;
293
+ const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
294
+ const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
295
+ const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
296
+ if (cute::elect_one_sync()) {
297
+ #pragma unroll
298
+ for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
299
+ uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K;
300
+ b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
301
+ #pragma unroll
302
+ for (uint32_t w = 0; w < kNumMWaves; ++ w) {
303
+ DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
304
+ a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
305
+ mma_t::fma(a_desc, b_desc,
306
+ accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
307
+ k_block_idx > 0 or k > 0,
308
+ runtime_instr_desc);
309
+ }
310
+ }
311
+ }
312
+
313
+ // Commit to the mbarrier object
314
+ // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
315
+ empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
316
+
317
+ // Let tensor cores relax for lower possibility of frequency drop
318
+ DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
319
+ if constexpr (kTensorCoreUtilControl < 100) {
320
+ // For utilization control
321
+ umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
322
+
323
+ // Wait for last UMMA to be done
324
+ tensor_core_full_barrier->wait(tensor_core_phase);
325
+ tensor_core_phase ^= 1;
326
+
327
+ // Sleep for certain cycles
328
+ constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull;
329
+ constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
330
+ const auto& start_clock = clock64();
331
+ if (cute::elect_one_sync())
332
+ while (clock64() - start_clock < kNumDummyCycles) {}
333
+ __syncwarp();
334
+ }
335
+ }
336
+ }
337
+
338
+ // To safely deconstruct barriers, we need another round of waits
339
+ const auto& iter_idx = scheduler.current_iter - 1;
340
+ if (kNumMulticast > 1 and iter_idx >= 0) {
341
+ const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
342
+ tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
343
+ }
344
+ } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
345
+ // Epilogue warp groups
346
+ const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
347
+
348
+ // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
349
+ // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
350
+ // NOTES: we also forbid two CTAs to share the same SM and its tensor memory
351
+ DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
352
+
353
+ // TMA checks
354
+ constexpr uint32_t kNumBankGroupBytes = 16;
355
+ constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
356
+ DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
357
+ DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
358
+
359
+ // Share store pipeline between blocks
360
+ uint32_t tma_stage_idx = 0;
361
+ auto advance_store_pipeline = [&]() {
362
+ tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
363
+ };
364
+
365
+ // Persistently schedule over blocks
366
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
367
+ auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
368
+ auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
369
+
370
+ // Wait UMMA arrival
371
+ tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
372
+ tcgen05_after_thread_sync();
373
+
374
+ // Load from tensor memory into registers, and write shared memory with STSM
375
+ DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
376
+ DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
377
+
378
+ // Iterate over M waves
379
+ #pragma unroll
380
+ for (uint32_t w = 0; w < kNumMWaves; ++ w) {
381
+ // Issue every swizzled atom and pipeline STSM and TMA store
382
+ constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
383
+ #pragma unroll
384
+ for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
385
+ // Wait shared memory to be released
386
+ if (epilogue_warp_idx == 0)
387
+ cute::tma_store_wait<kNumTMAStoreStages - 1>();
388
+ cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
389
+
390
+ // The pipeline stage
391
+ const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
392
+ const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
393
+
394
+ // Store into shared memory
395
+ #pragma unroll
396
+ for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
397
+ // Calculate the index of the bank group to be written in the atom
398
+ auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
399
+
400
+ // Reshape the atom in another view and swizzle
401
+ // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
402
+ // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
403
+ // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
404
+ constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
405
+ auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
406
+ auto col = kHasShortcut ? (i) : (bank_group_index % 8);
407
+ col ^= row % (kSwizzleCDMode / 16);
408
+
409
+ // Source and destination memory address
410
+ uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
411
+ w * BLOCK_N + // Wave offset
412
+ s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
413
+ auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
414
+ epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
415
+ row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
416
+
417
+ // Load from tensor memory, store into shared memory
418
+ uint32_t values[kNumElemsPerBankGroup];
419
+ if constexpr (cute::is_same_v<cd_dtype_t, float>) {
420
+ // For FP32 output, read and store
421
+ DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
422
+ cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
423
+ values[0], values[1], values[2], values[3]);
424
+ cutlass::arch::fence_view_async_tmem_load();
425
+ st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
426
+ } else {
427
+ // For BF16 output, read, cast and store
428
+ DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
429
+ cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
430
+ values[0], values[1], values[2], values[3],
431
+ values[4], values[5], values[6], values[7]);
432
+ cutlass::arch::fence_view_async_tmem_load();
433
+ st_shared(smem_ptr,
434
+ cast_into_bf16_and_pack(values[0], values[1]),
435
+ cast_into_bf16_and_pack(values[2], values[3]),
436
+ cast_into_bf16_and_pack(values[4], values[5]),
437
+ cast_into_bf16_and_pack(values[6], values[7]));
438
+ }
439
+ }
440
+
441
+ // Notify tensor memory empty (only at the leader CTA) arrival ASAP
442
+ // NOTES: only the last stage needs to do this
443
+ if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
444
+ tcgen05_before_thread_sync();
445
+ tmem_empty_barriers[accum_stage_idx]->arrive(0u);
446
+ }
447
+ __syncwarp();
448
+
449
+ // Synchronize all threads and issue TMA
450
+ cute::tma_store_fence();
451
+ cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
452
+ if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
453
+ if constexpr (kGemmType == GemmType::Batched) {
454
+ using cute_tma_t = cute::conditional_t<kWithAccumulation,
455
+ cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
456
+ cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx],
457
+ n_idx, m_idx, scheduler.current_group_idx);
458
+ } else {
459
+ using cute_tma_t = cute::conditional_t<kWithAccumulation,
460
+ cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
461
+ cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx);
462
+ }
463
+ cute::tma_store_arrive();
464
+ }
465
+ }
466
+ }
467
+ }
468
+
469
+ // Deallocate tensor memory by the last UMMA store warp
470
+ // NOTES: warp 0 is waiting TMA store
471
+ if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
472
+ Allocator().free(0, kNumTmemCols);
473
+ }
474
+ #else
475
+ if (blockIdx.x == 0 and threadIdx.x == 0)
476
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
477
+ #endif
478
+ }
479
+
480
+ }; // namespace deep_gemm
481
+
482
+ #pragma clang diagnostic pop
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/arch/cluster_sm90.hpp>
4
+ #include <cute/util/type_traits.hpp>
5
+ #include <cutlass/arch/barrier.h>
6
+
7
+ #include <deep_gemm/common/utils.cuh>
8
+ #include <deep_gemm/common/sm100_utils.cuh>
9
+
10
+ namespace deep_gemm {
11
+
12
+ using namespace deep_gemm::sm100;
13
+
14
+ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
15
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
16
+ uint32_t kSplitFactor,
17
+ uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
18
+ uint32_t kNumStages, uint32_t kNumThreads>
19
+ __global__ void __launch_bounds__(kNumThreads, 1)
20
+ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
21
+ const __grid_constant__ cute::TmaDescriptor tensor_map_a,
22
+ const __grid_constant__ cute::TmaDescriptor tensor_map_b,
23
+ const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
24
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
25
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
26
+
27
+ // Configs
28
+ constexpr uint32_t LAYOUT_AD_M = 128;
29
+ constexpr uint32_t kNumTMAStoreStages = 2;
30
+
31
+ // Utils
32
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
33
+ const auto lane_idx = get_lane_idx();
34
+ DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
35
+ DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
36
+
37
+ // Align to 1024 bytes for swizzle-128B
38
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
39
+
40
+ // Shared memory sizes
41
+ constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode;
42
+ constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
43
+ constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
44
+ constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
45
+
46
+ // Prefetch TMA descriptors at the very beginning
47
+ if (warp_idx == 0 and cute::elect_one_sync()) {
48
+ cute::prefetch_tma_descriptor(&tensor_map_a);
49
+ cute::prefetch_tma_descriptor(&tensor_map_b);
50
+ cute::prefetch_tma_descriptor(&tensor_map_d);
51
+ }
52
+
53
+ // Real tensor memory size and offsets
54
+ constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_N>();
55
+
56
+ // Fill D/A/B
57
+ auto smem_cd = PatternVisitor([&](const uint32_t& i) {
58
+ return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
59
+ });
60
+ auto smem_a = PatternVisitor([&](const uint32_t& i) {
61
+ return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
62
+ });
63
+ auto smem_b = PatternVisitor([&](const uint32_t& i) {
64
+ return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
65
+ });
66
+
67
+ // Fill barriers
68
+ auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
69
+ kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
70
+ auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
71
+ auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
72
+ auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
73
+
74
+ // Fill the tensor memory pointer
75
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 2 + 1);
76
+ DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
77
+
78
+ // Initialize barriers
79
+ if (warp_idx == 1 and cute::elect_one_sync()) {
80
+ #pragma unroll
81
+ for (uint32_t i = 0; i < kNumStages; ++ i) {
82
+ full_barriers[i]->init(1);
83
+ empty_barriers[i]->init(1);
84
+ }
85
+ tmem_full_barrier->init(1);
86
+
87
+ // Make initialized barrier visible in async proxy
88
+ cutlass::arch::fence_barrier_init();
89
+ } else if (warp_idx == 2) {
90
+ // Allocate tensor memory
91
+ cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
92
+ }
93
+ __syncthreads();
94
+
95
+ // Block indices
96
+ const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
97
+ const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
98
+ const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
99
+ const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
100
+ const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
101
+ const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
102
+ const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
103
+
104
+ if (warp_idx == 0) {
105
+ // TMA load warp
106
+ for (uint32_t s = 0; s < num_total_stages; ++ s) {
107
+ const auto& stage_idx = s % kNumStages;
108
+ empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
109
+
110
+ uint32_t m_idx = BLOCK_M * m_block_idx;
111
+ uint32_t n_idx = BLOCK_N * n_block_idx;
112
+ uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
113
+ uint32_t k_idx = sk_idx % SHAPE_K;
114
+ uint32_t s_idx = sk_idx / SHAPE_K;
115
+
116
+ // Issue TMAs
117
+ if (cute::elect_one_sync()) {
118
+ tma_copy<BLOCK_K, BLOCK_M, kSwizzleABMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
119
+ tma_copy<BLOCK_K, BLOCK_N, kSwizzleABMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
120
+ }
121
+
122
+ // Arrive at full barriers
123
+ constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
124
+ if (cute::elect_one_sync())
125
+ full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
126
+ }
127
+ } else if (warp_idx == 1) {
128
+ // MMA issue warp
129
+ // NOTES: only the leader CTA will do this
130
+ // Make instruction descriptor
131
+ constexpr uint32_t UMMA_M = LAYOUT_AD_M;
132
+ constexpr uint32_t UMMA_N = BLOCK_N;
133
+ constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
134
+ auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
135
+
136
+ DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
137
+ auto a_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
138
+ auto b_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
139
+ uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
140
+ uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
141
+
142
+ // Checks for MMA instructions
143
+ // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
144
+ DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
145
+ (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
146
+ (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
147
+ "Invalid MMA instruction shape");
148
+
149
+ // Wait tensor memory empty barrier arrival
150
+ tcgen05_after_thread_sync();
151
+
152
+ // Launch MMAs
153
+ for (uint32_t s = 0; s < num_total_stages; ++ s) {
154
+ // Wait TMA arrival
155
+ const auto& stage_idx = s % kNumStages;
156
+ full_barriers[stage_idx]->wait((s / kNumStages) & 1);
157
+ tcgen05_after_thread_sync();
158
+
159
+ // Issue UMMA in the leader CTA
160
+ const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
161
+ const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx);
162
+ const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx);
163
+ if (cute::elect_one_sync()) {
164
+ #pragma unroll
165
+ for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
166
+ a_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(a_desc_base_lo, 0, k * UMMA_K);
167
+ b_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
168
+ SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
169
+ }
170
+ }
171
+
172
+ // Commit to the mbarrier object
173
+ // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
174
+ cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
175
+ }
176
+ cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
177
+ }
178
+
179
+ // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
180
+ // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
181
+ // NOTES: we also forbid two CTAs to share the same SM and its tensor memory
182
+ if (warp_idx == 2)
183
+ DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
184
+
185
+ // TMA checks
186
+ constexpr uint32_t kNumBankGroupBytes = 16;
187
+ constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
188
+ constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float);
189
+ DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
190
+ DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
191
+
192
+ // Wait UMMA arrival
193
+ tmem_full_barrier->wait(0);
194
+ tcgen05_after_thread_sync();
195
+
196
+ // Load from tensor memory into registers, and write shared memory with STSM
197
+ DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
198
+
199
+ // Issue every swizzled atom and pipeline STSM and TMA store
200
+ constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
201
+ #pragma unroll
202
+ for (uint32_t s = 0; s < kNumStores; ++ s) {
203
+ // Wait shared memory to be released
204
+ if (s >= kNumTMAStoreStages) {
205
+ if (warp_idx == 0 and cute::elect_one_sync())
206
+ cute::tma_store_wait<kNumTMAStoreStages - 1>();
207
+ cutlass::arch::NamedBarrier(kNumThreads).sync();
208
+ }
209
+
210
+ // The pipeline stage
211
+ const auto tma_stage_idx = s % kNumTMAStoreStages;
212
+ const auto m_idx = m_block_idx * BLOCK_M;
213
+ const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
214
+
215
+ // Store into shared memory
216
+ #pragma unroll
217
+ for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
218
+ // Calculate the index of the bank group to be written in the atom
219
+ auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
220
+
221
+ // Reshape the atom in another view and swizzle
222
+ // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
223
+ // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
224
+ // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
225
+ constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
226
+ auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
227
+ auto col = kHasShortcut ? (i) : (bank_group_index % 8);
228
+ col ^= row % (kSwizzleCDMode / 16);
229
+
230
+ // Source and destination memory address
231
+ uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
232
+ auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
233
+ warp_idx * 32 * kSwizzleCDMode + // Warp offset
234
+ row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
235
+
236
+ // Load from tensor memory, store into shared memory
237
+ uint32_t values[kNumElemsPerBankGroup];
238
+ DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
239
+ cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
240
+ values[0], values[1], values[2], values[3]);
241
+ cutlass::arch::fence_view_async_tmem_load();
242
+ st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
243
+ }
244
+
245
+ // Synchronize all threads and issue TMA
246
+ cute::tma_store_fence();
247
+ cutlass::arch::NamedBarrier(kNumThreads).sync();
248
+ if (warp_idx == 0 and cute::elect_one_sync()) {
249
+ cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
250
+ cute::tma_store_arrive();
251
+ }
252
+ }
253
+
254
+ // Deallocate tensor memory by warp 1
255
+ // NOTES: warp 0 is doing TMA stores
256
+ if (warp_idx == 1)
257
+ cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
258
+
259
+ #else
260
+ if (blockIdx.x == 0 and threadIdx.x == 0)
261
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
262
+ #endif
263
+ }
264
+
265
+ }
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #pragma clang diagnostic push
3
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
4
+
5
+ #include <cutlass/arch/barrier.h>
6
+
7
+ #include <deep_gemm/common/epilogue_utils.cuh>
8
+ #include <deep_gemm/common/scheduler.cuh>
9
+ #include <deep_gemm/common/utils.cuh>
10
+ #include <deep_gemm/common/sm100_utils.cuh>
11
+
12
+ namespace deep_gemm {
13
+
14
+ using namespace deep_gemm::sm100;
15
+
16
+ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
17
+ uint32_t kGranKA, uint32_t kGranKB,
18
+ uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
19
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
20
+ uint32_t kNumGroups,
21
+ uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
22
+ uint32_t kNumStages,
23
+ uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
24
+ uint32_t kNumMulticast, bool kIsMulticastOnA,
25
+ uint32_t kNumSMs,
26
+ GemmType kGemmType, bool kWithAccumulation,
27
+ typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
28
+ typename epilogue_type_t>
29
+ __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
30
+ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
31
+ uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
32
+ const __grid_constant__ cute::TmaDescriptor tensor_map_a,
33
+ const __grid_constant__ cute::TmaDescriptor tensor_map_b,
34
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
35
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
36
+ const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
37
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
38
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
39
+ using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
40
+
41
+ // GEMM with accumulation must have FP32 output
42
+ if constexpr (kWithAccumulation)
43
+ DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
44
+
45
+ // Configs
46
+ constexpr uint32_t LAYOUT_AD_M = 128;
47
+ constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
48
+ constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
49
+ constexpr uint32_t kNumTMAStoreStages = 2;
50
+ constexpr uint32_t kNumUTCCPAlignedElems = 128;
51
+ DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
52
+ DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
53
+
54
+ constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
55
+ constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
56
+ DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
57
+ DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
58
+
59
+ // Overwrite shape constants if the compiler gives
60
+ shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
61
+ shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
62
+ shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
63
+ const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4);
64
+ const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4);
65
+
66
+ // Utils
67
+ bool is_leader_cta = cute::block_rank_in_cluster() == 0;
68
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
69
+ const auto lane_idx = get_lane_idx();
70
+
71
+ // Align to 1024 bytes for swizzle-128B
72
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
73
+
74
+ // 2-CTA MMA
75
+ constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
76
+ constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
77
+ constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
78
+ constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
79
+ constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M;
80
+ DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
81
+ DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D");
82
+ DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
83
+ DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
84
+
85
+ // Share memory sizes
86
+ constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
87
+ constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
88
+ constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
89
+ constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
90
+ constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
91
+ constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
92
+ constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
93
+ constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
94
+ DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
95
+ "Shared memory of A/B must be aligned to 1024 bytes");
96
+ DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
97
+
98
+ // NOTES: Make sure we have enough shared memory for UMMA padding
99
+ static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
100
+ DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
101
+
102
+ // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
103
+ // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
104
+ constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
105
+ constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
106
+ constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2;
107
+
108
+ // Real tensor memory size and offsets
109
+ constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
110
+ constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
111
+ constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
112
+ constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
113
+
114
+ // Prefetch TMA descriptors at the very beginning
115
+ if (warp_idx == 0 and cute::elect_one_sync()) {
116
+ cute::prefetch_tma_descriptor(&tensor_map_a);
117
+ cute::prefetch_tma_descriptor(&tensor_map_b);
118
+ cute::prefetch_tma_descriptor(&tensor_map_sfa);
119
+ cute::prefetch_tma_descriptor(&tensor_map_sfb);
120
+ cute::prefetch_tma_descriptor(&tensor_map_cd);
121
+ }
122
+
123
+ // D/A/B shared memory
124
+ auto smem_cd = PatternVisitor([&](const uint32_t& i) {
125
+ return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
126
+ });
127
+ auto smem_a = PatternVisitor([&](const uint32_t& i) {
128
+ return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
129
+ });
130
+ auto smem_b = PatternVisitor([&](const uint32_t& i) {
131
+ return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
132
+ });
133
+
134
+ // SFA/SFB shared memory
135
+ auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
136
+ auto smem_sfa = PatternVisitor([=](const uint32_t& i) {
137
+ return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
138
+ });
139
+ auto smem_sfb = PatternVisitor([=](const uint32_t& i) {
140
+ return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
141
+ });
142
+
143
+ // Fill barriers
144
+ auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
145
+ SMEM_CD_SIZE +
146
+ kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) +
147
+ kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE));
148
+ auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
149
+ auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
150
+ auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
151
+ auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
152
+ auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
153
+
154
+ // Fill the tensor memory pointer
155
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
156
+ DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
157
+
158
+ // Initialize barriers
159
+ if (warp_idx == 1 and cute::elect_one_sync()) {
160
+ #pragma unroll
161
+ for (uint32_t i = 0; i < kNumStages; ++ i) {
162
+ // Arrive at all CTAs
163
+ full_barriers[i]->init(1);
164
+ empty_barriers[i]->init(1);
165
+ // Arrive only at the leader CTA
166
+ with_sf_full_barriers[i]->init(kNumMulticast * 32);
167
+ }
168
+ #pragma unroll
169
+ for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
170
+ // Arrive at all CTAs
171
+ tmem_full_barriers[i]->init(1);
172
+ // Arrive only at the leader CTA
173
+ tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
174
+ }
175
+
176
+ // Make initialized barrier visible in async proxy
177
+ cutlass::arch::fence_barrier_init();
178
+ } else if (warp_idx == 2) {
179
+ // Allocate tensor memory
180
+ Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
181
+ }
182
+ kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
183
+
184
+ // Block scheduler
185
+ uint32_t m_block_idx, n_block_idx;
186
+ auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
187
+
188
+ // Pipeline and TMA phases
189
+ uint32_t stage_idx = 0, phase = 0;
190
+ auto advance_pipeline = [&](uint32_t& k_block_idx) {
191
+ ++ k_block_idx;
192
+
193
+ // Flip phases only if reach the next first stage
194
+ stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
195
+ phase ^= stage_idx == 0;
196
+ };
197
+
198
+ // Dispatch warps into different roles
199
+ if (warp_idx == 0 and cute::elect_one_sync()) {
200
+ // TMA load warp
201
+ // Persistently schedule over blocks
202
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
203
+ const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
204
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
205
+ // Wait consumer release
206
+ empty_barriers[stage_idx]->wait(phase ^ 1);
207
+
208
+ // Compute offsets
209
+ // NOTES: the group is always concatenated with the outer dimension
210
+ uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> (
211
+ shape_m, BLOCK_M, m_block_idx);
212
+ uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> (
213
+ shape_n, BLOCK_N, n_block_idx, m_block_idx);
214
+
215
+ // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
216
+ // And for all m-grouped GEMMs, A must be K-majored
217
+ DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
218
+ kMajorA == cute::UMMA::Major::K, "Invalid major");
219
+ uint32_t k_idx = k_block_idx * BLOCK_K;
220
+ uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
221
+ shape_k, BLOCK_K, k_block_idx, m_block_idx);
222
+ uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
223
+ shape_k, BLOCK_K, k_block_idx, m_block_idx);
224
+
225
+ // Add 2 CTA offsets
226
+ if constexpr (kNumMulticast > 1) {
227
+ m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
228
+ n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
229
+ }
230
+
231
+ // Issue TMAs
232
+ constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
233
+ const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
234
+ if constexpr (kMajorA == cute::UMMA::Major::K)
235
+ tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
236
+ &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
237
+ if constexpr (kMajorA == cute::UMMA::Major::MN)
238
+ tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
239
+ &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
240
+ if constexpr (kMajorB == cute::UMMA::Major::K)
241
+ tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
242
+ &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
243
+ if constexpr (kMajorB == cute::UMMA::Major::MN)
244
+ tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
245
+ &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
246
+ auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
247
+ SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
248
+
249
+ // Issue SFA and SFB TMAs at certain stages
250
+ // No swizzling, so one TMA for one SF is enough
251
+ if (k_block_idx % kNumSFAStagesPerLoad == 0) {
252
+ tma_copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M,
253
+ scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)));
254
+ num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
255
+ }
256
+ if (k_block_idx % kNumSFBStagesPerLoad == 0) {
257
+ tma_copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N,
258
+ scheduler.template get_global_idx<true, IndexType::SF_K>(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx));
259
+ num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
260
+ }
261
+
262
+ // Arrive at full barriers
263
+ full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
264
+ }
265
+ }
266
+ } else if (warp_idx == 1 and is_leader_cta) {
267
+ // MMA issue warp
268
+ // NOTES: only the leader CTA will do this
269
+ // Make instruction descriptor
270
+ // TODO: refactor `UMMA_M` calculation
271
+ constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
272
+ constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
273
+ constexpr uint32_t UMMA_K = 32;
274
+ auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
275
+ UMMA_M, UMMA_N, kMajorA, kMajorB>();
276
+ auto sf_desc = make_sf_desc(nullptr);
277
+
278
+ DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
279
+ auto a_desc = make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
280
+ auto b_desc = make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
281
+ uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
282
+ uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
283
+
284
+ // Checks for MMA instructions
285
+ // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
286
+ DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
287
+ (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
288
+ (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
289
+ "Invalid MMA instruction shape");
290
+
291
+ // Persistently schedule over blocks
292
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
293
+ // Wait tensor memory empty barrier arrival
294
+ auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
295
+ auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
296
+ tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
297
+ tcgen05_after_thread_sync();
298
+
299
+ // Empty barrier arrival
300
+ auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
301
+ auto umma_arrive = [](const uint64_t* barrier) {
302
+ if constexpr (kNumMulticast == 1) {
303
+ cutlass::arch::umma_arrive(barrier);
304
+ } else {
305
+ constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
306
+ cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
307
+ }
308
+ };
309
+ umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
310
+
311
+ // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
312
+ if (do_tmem_full_arrive)
313
+ umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
314
+ };
315
+
316
+ // Launch MMAs
317
+ const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
318
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
319
+ // Wait TMA and SF-transpose arrival
320
+ with_sf_full_barriers[stage_idx]->wait(phase);
321
+ tcgen05_after_thread_sync();
322
+
323
+ // Do SF copy at certain stages
324
+ // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
325
+ // TODO: process shared memory descriptor by addition
326
+ using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
327
+ cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
328
+ const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
329
+ if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) {
330
+ #pragma unroll
331
+ for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
332
+ auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
333
+ replace_smem_desc_addr(sf_desc, smem_ptr);
334
+ cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
335
+ }
336
+ }
337
+ const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
338
+ if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) {
339
+ #pragma unroll
340
+ for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
341
+ auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
342
+ replace_smem_desc_addr(sf_desc, smem_ptr);
343
+ cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
344
+ }
345
+ }
346
+ __syncwarp();
347
+
348
+ // Issue UMMA in the leader CTA
349
+ using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_MXF8F6F4_SS, SM100_MMA_MXF8F6F4_2x1SM_SS>;
350
+ const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
351
+ const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
352
+ if (cute::elect_one_sync()) {
353
+ #pragma unroll
354
+ for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
355
+ const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
356
+ const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
357
+ const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
358
+
359
+ b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
360
+ #pragma unroll
361
+ for (uint32_t w = 0; w < kNumMWaves; ++ w) {
362
+ DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
363
+ a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K);
364
+ mma_t::fma(a_desc, b_desc,
365
+ accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
366
+ k_block_idx > 0 or k > 0,
367
+ runtime_instr_desc,
368
+ kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
369
+ kTmemStartColOfSFB);
370
+ }
371
+ }
372
+ }
373
+
374
+ // Commit to the mbarrier object
375
+ // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
376
+ empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
377
+ }
378
+ }
379
+
380
+ // To safely deconstruct barriers, we need another round of waits
381
+ const auto& iter_idx = scheduler.current_iter - 1;
382
+ if (kNumMulticast > 1 and iter_idx >= 0) {
383
+ const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
384
+ tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
385
+ }
386
+ } else if (warp_idx == 2) {
387
+ // UTCCP transposer
388
+ auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
389
+ DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
390
+ uint32_t values[4];
391
+ #pragma unroll
392
+ for (uint32_t i = 0; i < 4; ++ i)
393
+ values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
394
+ __syncwarp();
395
+ #pragma unroll
396
+ for (uint32_t i = 0; i < 4; ++ i)
397
+ st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
398
+ };
399
+
400
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
401
+ const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
402
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
403
+ // Wait TMA arrival
404
+ full_barriers[stage_idx]->wait(phase);
405
+
406
+ // Transpose for UTCCP at certain stages
407
+ if (k_block_idx % kNumSFAStagesPerLoad == 0) {
408
+ #pragma unroll
409
+ for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
410
+ utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
411
+ // TODO: figure out whether the proxy fence is valid for 2-CTA cases
412
+ cutlass::arch::fence_view_async_shared();
413
+ }
414
+ if (k_block_idx % kNumSFBStagesPerLoad == 0) {
415
+ #pragma unroll
416
+ for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
417
+ utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
418
+ // TODO: figure out whether the proxy fence is valid for 2-CTA cases
419
+ cutlass::arch::fence_view_async_shared();
420
+ }
421
+
422
+ // Arrive
423
+ with_sf_full_barriers[stage_idx]->arrive(0u);
424
+ }
425
+ }
426
+ } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
427
+ // Epilogue warp groups
428
+ const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
429
+
430
+ // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
431
+ // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
432
+ // NOTES: we also forbid two CTAs to share the same SM and its tensor memory
433
+ DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
434
+
435
+ // TMA checks
436
+ constexpr uint32_t kNumBankGroupBytes = 16;
437
+ constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
438
+ DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
439
+ DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
440
+
441
+ // Share store pipeline between blocks
442
+ uint32_t tma_stage_idx = 0;
443
+ auto advance_store_pipeline = [&]() {
444
+ tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
445
+ };
446
+
447
+ // Persistently schedule over blocks
448
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
449
+ auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
450
+ auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
451
+
452
+ // Wait UMMA arrival
453
+ tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
454
+ tcgen05_after_thread_sync();
455
+
456
+ // Load from tensor memory into registers, and write shared memory with STSM
457
+ DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
458
+ DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
459
+
460
+ // Iterate over M waves
461
+ #pragma unroll
462
+ for (uint32_t w = 0; w < kNumMWaves; ++ w) {
463
+ // Issue every swizzled atom and pipeline STSM and TMA store
464
+ constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
465
+ #pragma unroll
466
+ for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
467
+ // Wait shared memory to be released
468
+ if (epilogue_warp_idx == 0)
469
+ cute::tma_store_wait<kNumTMAStoreStages - 1>();
470
+ cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
471
+
472
+ // The pipeline stage
473
+ const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
474
+ const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
475
+
476
+ // Store into shared memory
477
+ #pragma unroll
478
+ for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
479
+ // Calculate the index of the bank group to be written in the atom
480
+ auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
481
+
482
+ // Reshape the atom in another view and swizzle
483
+ // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
484
+ // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
485
+ // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
486
+ constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
487
+ auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
488
+ auto col = kHasShortcut ? (i) : (bank_group_index % 8);
489
+ col ^= row % (kSwizzleCDMode / 16);
490
+
491
+ // Source and destination memory address
492
+ uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
493
+ w * BLOCK_N + // Wave offset
494
+ s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
495
+ auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
496
+ epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
497
+ row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
498
+
499
+ // Load from tensor memory, store into shared memory
500
+ uint32_t values[kNumElemsPerBankGroup];
501
+ if constexpr (cute::is_same_v<cd_dtype_t, float>) {
502
+ // For FP32 output, read and store
503
+ DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
504
+ cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
505
+ values[0], values[1], values[2], values[3]);
506
+ cutlass::arch::fence_view_async_tmem_load();
507
+ st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
508
+ } else {
509
+ // For BF16 output, read, cast and store
510
+ DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
511
+ cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
512
+ values[0], values[1], values[2], values[3],
513
+ values[4], values[5], values[6], values[7]);
514
+ cutlass::arch::fence_view_async_tmem_load();
515
+ st_shared(smem_ptr,
516
+ cast_into_bf16_and_pack(values[0], values[1]),
517
+ cast_into_bf16_and_pack(values[2], values[3]),
518
+ cast_into_bf16_and_pack(values[4], values[5]),
519
+ cast_into_bf16_and_pack(values[6], values[7]));
520
+ }
521
+ }
522
+
523
+ // Notify tensor memory empty (only at the leader CTA) arrival ASAP
524
+ // NOTES: only the last stage needs to do this
525
+ if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
526
+ tcgen05_before_thread_sync();
527
+ tmem_empty_barriers[accum_stage_idx]->arrive(0u);
528
+ }
529
+
530
+ // Synchronize all threads and issue TMA
531
+ cute::tma_store_fence();
532
+ cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
533
+ if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
534
+ if constexpr (kGemmType == GemmType::Batched) {
535
+ using cute_tma_t = cute::conditional_t<kWithAccumulation,
536
+ cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
537
+ cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx],
538
+ n_idx, m_idx, scheduler.current_group_idx);
539
+ } else {
540
+ using cute_tma_t = cute::conditional_t<kWithAccumulation,
541
+ cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
542
+ cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx);
543
+ }
544
+ cute::tma_store_arrive();
545
+ }
546
+ }
547
+ }
548
+ }
549
+
550
+ // Deallocate tensor memory by the last UMMA store warp
551
+ // NOTES: warp 0 is waiting TMA store
552
+ if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
553
+ Allocator().free(0, kNumTmemCols);
554
+ }
555
+ #else
556
+ if (blockIdx.x == 0 and threadIdx.x == 0)
557
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
558
+ #endif
559
+ }
560
+
561
+ }; // namespace deep_gemm
562
+
563
+ #pragma clang diagnostic pop
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/arch/barrier.h>
4
+ #include <cutlass/arch/reg_reconfig.h>
5
+
6
+ #include <cute/arch/cluster_sm90.hpp>
7
+ #include <cute/arch/copy_sm90_desc.hpp>
8
+
9
+ #include <deep_gemm/common/utils.cuh>
10
+ #include <deep_gemm/common/sm90_utils.cuh>
11
+ #include <deep_gemm/common/sm100_utils.cuh>
12
+
13
+ namespace deep_gemm {
14
+
15
+ using namespace deep_gemm::sm90;
16
+ using namespace deep_gemm::sm100;
17
+
18
+ template <uint32_t kNumHeads, uint32_t kHeadDim,
19
+ bool kIsCompressedLogits,
20
+ uint32_t BLOCK_Q, uint32_t BLOCK_KV,
21
+ uint32_t kNumQStages, uint32_t kNumKVStages,
22
+ uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
23
+ uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
24
+ __global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
25
+ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
26
+ const uint32_t max_seqlen_k, const uint64_t stride_logits,
27
+ uint32_t* cu_seq_len_k_start,
28
+ uint32_t* cu_seq_len_k_end,
29
+ float* logits,
30
+ const __grid_constant__ cute::TmaDescriptor tensor_map_q,
31
+ const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
32
+ const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
33
+ const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
34
+ // TODO: consider TMA multicast
35
+ // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64`
36
+ // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
37
+ // Q should be load only at once for a block
38
+ const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
39
+
40
+ // Types
41
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
42
+
43
+ // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
44
+ const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
45
+ const auto& warp_in_group_idx = warp_idx % 4;
46
+ const auto& warpgroup_idx = warp_idx / 4;
47
+ const auto& lane_idx = get_lane_idx();
48
+
49
+ // Prefetch TMA descriptors
50
+ DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
51
+ if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
52
+ cute::prefetch_tma_descriptor(&tensor_map_q);
53
+ cute::prefetch_tma_descriptor(&tensor_map_kv);
54
+ cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
55
+ cute::prefetch_tma_descriptor(&tensor_map_weights);
56
+ }
57
+ __syncwarp();
58
+
59
+ // Shared memory configs
60
+ // NOTES: weight may be unaligned
61
+ static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
62
+ static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
63
+ static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
64
+ static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
65
+ static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u);
66
+
67
+ // Align to 512 bytes for swizzle-64B
68
+ extern __shared__ __align__(512) uint8_t smem_buffer[];
69
+ DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling");
70
+ DG_STATIC_ASSERT(SMEM_WEIGHT_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling");
71
+ DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling");
72
+
73
+ // TMA configs
74
+ constexpr uint32_t kNumTmemCols = BLOCK_Q * kNumHeads * kNumMathWarpGroups;
75
+ DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
76
+
77
+ // Data on shared memory
78
+ auto smem_q = PatternVisitor([&](const uint32_t& i) {
79
+ return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
80
+ SMEM_Q_SIZE_PER_STAGE * i);
81
+ });
82
+ auto smem_weights = PatternVisitor([&](const uint32_t& i) {
83
+ return reinterpret_cast<float*>(smem_buffer +
84
+ SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
85
+ });
86
+ auto smem_kv = PatternVisitor([&](const uint32_t& i) {
87
+ return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
88
+ SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
89
+ });
90
+ auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
91
+ return reinterpret_cast<float*>(smem_buffer +
92
+ SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages +
93
+ SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
94
+ });
95
+
96
+ // TMA barriers
97
+ auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
98
+ auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
99
+ auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
100
+ auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
101
+ auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
102
+ auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); });
103
+ auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); });
104
+
105
+ // Tensor memory allocation
106
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2);
107
+
108
+ // Initialize barriers
109
+ DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads");
110
+ const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32));
111
+ const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1));
112
+ if (is_tma_load_warp and cute::elect_one_sync()) {
113
+ #pragma unroll
114
+ for (uint32_t i = 0; i < kNumQStages; ++ i) {
115
+ full_q_barriers[i]->init(1);
116
+ empty_q_barriers[i]->init(kNumMathThreads);
117
+ }
118
+ #pragma unroll
119
+ for (uint32_t i = 0; i < kNumKVStages; ++ i) {
120
+ full_kv_barriers[i]->init(1);
121
+ empty_kv_barriers[i]->init(kNumMathThreads);
122
+ }
123
+ #pragma unroll
124
+ for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
125
+ full_umma_barriers[i]->init(1);
126
+ empty_umma_barriers[i]->init(128);
127
+ }
128
+
129
+ // Make initialized barrier visible in async proxy
130
+ cutlass::arch::fence_barrier_init();
131
+ } else if (is_umma_warp) {
132
+ // Allocate tensor memory
133
+ cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
134
+ }
135
+ __syncthreads();
136
+
137
+ // Register reconfigurations
138
+ constexpr uint32_t kNumSpecializedRegisters = 24;
139
+ constexpr uint32_t kNumMathRegisters = 240;
140
+
141
+ // Block scheduler
142
+ uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
143
+ const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
144
+ return {block_q_idx + gridDim.x, q_iter_idx + 1};
145
+ };
146
+ uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
147
+ const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
148
+ uint32_t start = cute::numeric_limits<uint32_t>::max();
149
+ uint32_t end = cute::numeric_limits<uint32_t>::min();
150
+
151
+ #pragma unroll
152
+ for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
153
+ const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
154
+ seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
155
+ seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
156
+ start = min(start, min(seq_k_start[i], seq_len_kv));
157
+ end = max(end, min(seq_k_end[i], seq_len_kv));
158
+ }
159
+ start = start / 4 * 4;
160
+ return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
161
+ ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
162
+ start, ceil_div(end - start, BLOCK_KV)}; // Task info
163
+ };
164
+
165
+ // KV pipeline
166
+ uint32_t num_total_kv_blocks = 0;
167
+ const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
168
+ return {
169
+ (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
170
+ ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
171
+ };
172
+ };
173
+
174
+ // UMMA settings
175
+ // Construct instruction with layout D
176
+ constexpr uint32_t UMMA_M = 128;
177
+ constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
178
+ constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
179
+
180
+ if (is_tma_load_warp) {
181
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
182
+
183
+ // Prefetch
184
+ const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
185
+ tma_copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
186
+ tma_copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
187
+ full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
188
+ };
189
+ if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
190
+ issue_tma_q(0, block_q_idx);
191
+
192
+ // Only the first lane persistently schedules over blocks
193
+ if (cute::elect_one_sync()) {
194
+ while (block_q_idx < num_q_blocks) {
195
+ CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks);
196
+
197
+ // Wait Q consumer release
198
+ empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
199
+
200
+ // Issue TMA Q
201
+ if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks)
202
+ issue_tma_q(q_stage_idx, next_block_q_idx);
203
+
204
+ // Issue TMA KV
205
+ #pragma unroll
206
+ for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
207
+ // Wait consumer release
208
+ CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
209
+ empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
210
+
211
+ // Issue TMA KV
212
+ tma_copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
213
+ smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
214
+ tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
215
+ smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
216
+ full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
217
+ }
218
+ num_total_kv_blocks += num_kv_blocks;
219
+
220
+ // Jump to the next block
221
+ CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
222
+ }
223
+ }
224
+ } else if (is_umma_warp) {
225
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
226
+
227
+ // Require full allocation
228
+ DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
229
+
230
+ // Make UMMA desc
231
+ auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
232
+ UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
233
+ auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
234
+
235
+ while (block_q_idx < num_q_blocks) {
236
+ CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
237
+
238
+ // Wait TMA Q arrival
239
+ full_q_barriers[q_stage_idx]->wait(q_phase);
240
+
241
+ // Compute over KV blocks
242
+ #pragma unroll
243
+ for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
244
+ // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
245
+ // Wait TMA KV arrival
246
+ CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
247
+ full_kv_barriers[kv_stage_idx]->wait(kv_phase);
248
+
249
+ // Issue UMMA
250
+ DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size");
251
+ DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
252
+ #pragma unroll
253
+ for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
254
+ empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
255
+ tcgen05_after_thread_sync();
256
+ #pragma unroll
257
+ for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
258
+ auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
259
+ smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
260
+ auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
261
+ smem_q[q_stage_idx], 0, k * UMMA_K);
262
+ cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
263
+ }
264
+ cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(full_umma_barriers[i]));
265
+ }
266
+ }
267
+ num_total_kv_blocks += num_kv_blocks;
268
+
269
+ // Jump to the next block
270
+ CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
271
+ }
272
+ } else if (warp_idx >= kNumMathThreads / 32) {
273
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
274
+ } else if (warp_idx < kNumMathThreads / 32) {
275
+ cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
276
+
277
+ // Offsets
278
+ const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
279
+ const auto& warp_offset = warp_idx * 32;
280
+ const auto& v_offset = lane_idx;
281
+
282
+ // Preload weights
283
+ constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads);
284
+ float weights[BLOCK_Q][kNumWeightsInReg];
285
+ DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
286
+
287
+ while (block_q_idx < num_q_blocks) {
288
+ CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
289
+
290
+ // Wait TMA Q arrival
291
+ full_q_barriers[q_stage_idx]->wait(q_phase);
292
+
293
+ // Read weights
294
+ #pragma unroll
295
+ for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
296
+ for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) {
297
+ weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
298
+ }
299
+ }
300
+
301
+ // Compute over KV blocks
302
+ #pragma unroll
303
+ for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
304
+ // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
305
+ // Wait TMA KV arrival
306
+ CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
307
+ full_kv_barriers[kv_stage_idx]->wait(kv_phase);
308
+
309
+ // Read per-KV scales
310
+ float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset);
311
+
312
+ // Wait UMMA arrival
313
+ full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
314
+ tcgen05_after_thread_sync();
315
+
316
+ // Release KV empty
317
+ empty_kv_barriers[kv_stage_idx]->arrive();
318
+
319
+ // Reduce over the head dim and store
320
+ const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
321
+ static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
322
+ DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
323
+
324
+ constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q;
325
+ DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems");
326
+ uint32_t shifted_accum[kNumLDTMElems];
327
+ auto tmem_load = [&](auto... Is) {
328
+ if constexpr (kNumLDTMElems == 32) {
329
+ cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
330
+ } else if constexpr (kNumLDTMElems == 64) {
331
+ cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
332
+ } else if constexpr (kNumLDTMElems == 128) {
333
+ cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
334
+ }
335
+ };
336
+ [&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
337
+ cutlass::arch::fence_view_async_tmem_load();
338
+
339
+ tcgen05_before_thread_sync();
340
+ empty_umma_barriers[warpgroup_idx]->arrive();
341
+
342
+ #pragma unroll
343
+ for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
344
+ auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
345
+
346
+ auto sum_0 = make_float2(0, 0);
347
+ auto sum_1 = make_float2(0, 0);
348
+
349
+ const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
350
+ auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
351
+ auto b = make_float2(weights[i][j], weights[i][j + 1]);
352
+ return __ffma2_rn(a, b, sum);
353
+ };
354
+
355
+ #pragma unroll
356
+ for (int j = 0; j < kNumWeightsInReg; j += 4) {
357
+ sum_0 = transform_reg(j, sum_0);
358
+ sum_1 = transform_reg(j + 2, sum_1);
359
+ }
360
+
361
+ const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
362
+ auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
363
+ auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
364
+ ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
365
+ return __ffma2_rn(a, b, sum);
366
+ };
367
+
368
+ #pragma unroll
369
+ for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
370
+ sum_0 = transform_smem(j, sum_0);
371
+ sum_1 = transform_smem(j + 2, sum_1);
372
+ }
373
+
374
+ auto sum = __fadd2_rn(sum_0, sum_1);
375
+ float result = scale_kv * (sum.x + sum.y);
376
+
377
+ // Store into the global memory
378
+ // NOTES: we have redundant writes here, consider more carefully
379
+ const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
380
+ if constexpr (kIsCompressedLogits) {
381
+ if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i])
382
+ logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result;
383
+ } else {
384
+ logits[q_idx * stride_logits + kv_offset + v_offset] = result;
385
+ }
386
+ }
387
+ }
388
+ num_total_kv_blocks += num_kv_blocks;
389
+
390
+ // Release Q empty
391
+ empty_q_barriers[q_stage_idx]->arrive();
392
+
393
+ // Jump to the next block
394
+ CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
395
+ }
396
+ }
397
+
398
+ // Free tensor memory
399
+ __syncthreads();
400
+ if (is_tma_load_warp)
401
+ cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
402
+ }
403
+
404
+ } // namespace deep_gemm
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/arch/barrier.h>
4
+ #include <cutlass/arch/reg_reconfig.h>
5
+
6
+ #include <cute/arch/cluster_sm90.hpp>
7
+ #include <cute/arch/copy_sm90_desc.hpp>
8
+
9
+ #include <deep_gemm/common/utils.cuh>
10
+ #include <deep_gemm/common/sm90_utils.cuh>
11
+ #include <deep_gemm/common/sm100_utils.cuh>
12
+
13
+ #include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
14
+
15
+ namespace deep_gemm {
16
+
17
+ using namespace deep_gemm::sm90;
18
+ using namespace deep_gemm::sm100;
19
+
20
+ template <uint32_t kNextN, uint32_t kNumHeads,
21
+ uint32_t kHeadDim, uint32_t BLOCK_KV,
22
+ bool kIsContextLens2D,
23
+ uint32_t kNumQStages, uint32_t kNumKVStages,
24
+ uint32_t SPLIT_KV,
25
+ uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
26
+ uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
27
+ __global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
28
+ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
29
+ const uint64_t logits_stride, const uint64_t block_table_stride,
30
+ const uint32_t* context_lens, float* logits,
31
+ const uint32_t* block_table, const uint32_t* schedule_meta,
32
+ const __grid_constant__ cute::TmaDescriptor tensor_map_q,
33
+ const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
34
+ const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
35
+ const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
36
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
37
+
38
+ // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
39
+ const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
40
+ const auto& warpgroup_idx = warp_idx / 4;
41
+ const auto& lane_idx = get_lane_idx();
42
+
43
+ // Prefetch TMA descriptors
44
+ DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
45
+ if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
46
+ cute::prefetch_tma_descriptor(&tensor_map_q);
47
+ cute::prefetch_tma_descriptor(&tensor_map_kv);
48
+ cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
49
+ cute::prefetch_tma_descriptor(&tensor_map_weights);
50
+ }
51
+ __syncwarp();
52
+
53
+ // Shared memory configs
54
+ static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
55
+ static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
56
+ static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
57
+ static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
58
+ static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
59
+
60
+ // Align to swizzling alignment bytes
61
+ extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
62
+ DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
63
+ DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
64
+
65
+ // Q and KV data on shared memory
66
+ auto smem_q = PatternVisitor([&](const uint32_t& i) {
67
+ return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
68
+ });
69
+ auto smem_kv = PatternVisitor([&](const uint32_t& i) {
70
+ return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
71
+ });
72
+ constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
73
+ auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
74
+ return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
75
+ });
76
+ auto smem_weights = PatternVisitor([&](const uint32_t& i) {
77
+ return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
78
+ });
79
+
80
+ // Barriers and TMEM pointer on shared memory
81
+ const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
82
+ auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
83
+ auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
84
+ auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
85
+ auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
86
+ const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
87
+ auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
88
+ auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
89
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
90
+
91
+ constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups;
92
+ DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
93
+ const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4);
94
+ const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4);
95
+ const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1);
96
+
97
+ // Initialize barriers
98
+ if (is_tma_load_warp and cute::elect_one_sync()) {
99
+ #pragma unroll
100
+ for (uint32_t i = 0; i < kNumQStages; ++ i) {
101
+ full_q_barriers[i]->init(1);
102
+ empty_q_barriers[i]->init(kNumMathThreads);
103
+ }
104
+ #pragma unroll
105
+ for (uint32_t i = 0; i < kNumKVStages; ++ i) {
106
+ full_kv_barriers[i]->init(1);
107
+ empty_kv_barriers[i]->init(kNumMathThreads);
108
+ }
109
+ cutlass::arch::fence_barrier_init();
110
+ }
111
+ if (is_umma_warp) {
112
+ if (cute::elect_one_sync()) {
113
+ #pragma unroll
114
+ for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) {
115
+ full_umma_barriers[i]->init(1);
116
+ empty_umma_barriers[i]->init(128);
117
+ }
118
+ cutlass::arch::fence_barrier_init();
119
+ }
120
+ // Allocate tensor memory
121
+ cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
122
+ }
123
+ __syncthreads();
124
+
125
+ // Register reconfigurations
126
+ constexpr uint32_t kNumSpecializedRegisters = 40;
127
+ constexpr uint32_t kNumMathRegisters = 232;
128
+
129
+ // Scheduler
130
+ constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
131
+ auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit>(batch_size, blockIdx.x, context_lens, schedule_meta);
132
+ DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
133
+
134
+ // Q and KV pipeline
135
+ const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
136
+ return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
137
+ };
138
+ const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
139
+ return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
140
+ };
141
+ uint32_t q_iter_idx = 0, kv_iter_idx = 0;
142
+
143
+ // UMMA settings
144
+ // Construct instruction with layout D
145
+ constexpr uint32_t UMMA_M = 128;
146
+ constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
147
+ constexpr uint32_t UMMA_N = kNextN * kNumHeads;
148
+ DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
149
+
150
+ if (is_tma_load_warp) {
151
+ // TMA warp-group for loading data
152
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
153
+
154
+ const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
155
+ if (cute::elect_one_sync()) {
156
+ tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
157
+ tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
158
+ full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
159
+ }
160
+ };
161
+
162
+ // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
163
+ uint32_t q_idx = batch_size, kv_idx, num_kv;
164
+ uint32_t next_q_idx, next_kv_idx, next_num_kv;
165
+ bool fetched_next_task;
166
+
167
+ // Prefetch the first Q
168
+ if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)))
169
+ issue_tma_q(0, next_q_idx), q_iter_idx = 1;
170
+
171
+ int kv_block_idx_ptr = 32;
172
+ uint32_t kv_block_idx_storage;
173
+
174
+ while (fetched_next_task) {
175
+ // Prefetch next Q when current Q changes
176
+ bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
177
+ q_idx = next_q_idx;
178
+ kv_idx = next_kv_idx;
179
+ num_kv = next_num_kv;
180
+
181
+ // Read KV block index
182
+ // TODO: deal with `-1`?
183
+ if (kv_idx == 0 or kv_block_idx_ptr == 32) {
184
+ kv_block_idx_ptr = 0;
185
+ kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0);
186
+ }
187
+ DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
188
+
189
+ // Wait Q consumer release and issue TMA Q
190
+ if (prefetch_q) {
191
+ CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
192
+ empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
193
+ issue_tma_q(q_stage_idx, q_idx + 1);
194
+ }
195
+
196
+ int kv_block_idx[kNumBlocksPerSplit];
197
+ #pragma unroll
198
+ for (int i = 0; i < kNumBlocksPerSplit; ++ i)
199
+ kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
200
+ kv_block_idx_ptr += kNumBlocksPerSplit;
201
+
202
+ // Wait KV consumer release
203
+ CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
204
+ empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
205
+
206
+ if (cute::elect_one_sync()) {
207
+ #pragma unroll
208
+ for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
209
+ tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
210
+ smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
211
+ 0, 0, 1, kv_block_idx[i]);
212
+ tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
213
+ smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
214
+ 0, kv_block_idx[i]);
215
+ }
216
+ full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
217
+ }
218
+
219
+ // Fetch next task
220
+ fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv);
221
+ }
222
+ } else if (is_umma_warp) {
223
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
224
+
225
+ // Require full allocation
226
+ DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
227
+
228
+ // Make UMMA desc
229
+ auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
230
+ UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
231
+ auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
232
+
233
+ uint32_t q_idx = batch_size, kv_idx;
234
+ uint32_t next_q_idx, next_kv_idx, next_num_kv;
235
+ uint32_t q_stage_idx, q_phase;
236
+ uint32_t umma_phase = 1;
237
+
238
+ while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
239
+ if (q_idx != next_q_idx) {
240
+ CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
241
+ full_q_barriers[q_stage_idx]->wait(q_phase);
242
+ }
243
+
244
+ q_idx = next_q_idx;
245
+ kv_idx = next_kv_idx;
246
+
247
+ CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
248
+ full_kv_barriers[kv_stage_idx]->wait(kv_phase);
249
+
250
+ DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
251
+ #pragma unroll
252
+ for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
253
+ empty_umma_barriers[i]->wait(umma_phase);
254
+ tcgen05_after_thread_sync();
255
+ #pragma unroll
256
+ for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
257
+ auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
258
+ smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
259
+ auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
260
+ smem_q[q_stage_idx], 0, k * UMMA_K);
261
+ cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
262
+ }
263
+ cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(full_umma_barriers[i]));
264
+ }
265
+ umma_phase ^= 1;
266
+ }
267
+ } else if (is_math_warp) {
268
+ // Math warp-groups for WGMMA
269
+ cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
270
+
271
+ // Offsets
272
+ const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
273
+ const uint32_t thread_idx = threadIdx.x;
274
+
275
+ // Weights
276
+ constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads));
277
+ float weights[kNextN][kNumWeightsInReg];
278
+ DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
279
+
280
+ // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
281
+ uint32_t q_idx = batch_size, kv_idx;
282
+ uint32_t next_q_idx, next_kv_idx, next_num_kv;
283
+ uint32_t q_stage_idx, q_phase;
284
+ uint32_t umma_phase = 0;
285
+
286
+ while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
287
+ // Current Q changes
288
+ if (q_idx != next_q_idx) {
289
+ // Release Last Q empty
290
+ if (q_iter_idx > 0)
291
+ empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
292
+
293
+ // Wait TMA Q arrival
294
+ CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
295
+ full_q_barriers[q_stage_idx]->wait(q_phase);
296
+
297
+ // Read weights
298
+ #pragma unroll
299
+ for (uint32_t i = 0; i < kNextN; ++ i) {
300
+ for (uint32_t j = 0; j < kNumWeightsInReg; ++ j)
301
+ weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
302
+ }
303
+ }
304
+
305
+ // Get current Q and KV index
306
+ q_idx = next_q_idx;
307
+ kv_idx = next_kv_idx;
308
+
309
+ // Calculate KV offset in advance
310
+ auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV;
311
+
312
+ // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]`
313
+ // Wait TMA KV arrival
314
+ CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
315
+ full_kv_barriers[kv_stage_idx]->wait(kv_phase);
316
+
317
+ // Read per-KV scales
318
+ float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx);
319
+
320
+ // Wait UMMA arrival
321
+ full_umma_barriers[warpgroup_idx]->wait(umma_phase);
322
+ tcgen05_after_thread_sync();
323
+ umma_phase ^= 1;
324
+
325
+ // Release KV empty
326
+ empty_kv_barriers[kv_stage_idx]->arrive();
327
+
328
+ // Reduce over the head dim and store
329
+ DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
330
+ constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN;
331
+ uint32_t shifted_accum[kNumLDTMElems];
332
+ DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM");
333
+ auto tmem_load = [&](auto... Is) {
334
+ if constexpr (kNumLDTMElems == 32) {
335
+ cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
336
+ } else if constexpr (kNumLDTMElems == 64) {
337
+ cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
338
+ } else if constexpr (kNumLDTMElems == 128) {
339
+ cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
340
+ }
341
+ };
342
+ [&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
343
+ cutlass::arch::fence_view_async_tmem_load();
344
+
345
+ tcgen05_before_thread_sync();
346
+ empty_umma_barriers[warpgroup_idx]->arrive();
347
+
348
+ #pragma unroll
349
+ for (uint32_t i = 0; i < kNextN; ++ i) {
350
+ auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
351
+
352
+ auto sum_0 = make_float2(0, 0);
353
+ auto sum_1 = make_float2(0, 0);
354
+
355
+ const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
356
+ auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
357
+ auto b = make_float2(weights[i][j], weights[i][j + 1]);
358
+ return __ffma2_rn(a, b, sum);
359
+ };
360
+
361
+ #pragma unroll
362
+ for (int j = 0; j < kNumWeightsInReg; j += 4) {
363
+ sum_0 = transform_reg(j, sum_0);
364
+ sum_1 = transform_reg(j + 2, sum_1);
365
+ }
366
+
367
+ const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
368
+ auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
369
+ auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
370
+ ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
371
+ return __ffma2_rn(a, b, sum);
372
+ };
373
+
374
+ #pragma unroll
375
+ for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
376
+ sum_0 = transform_smem(j, sum_0);
377
+ sum_1 = transform_smem(j + 2, sum_1);
378
+ }
379
+
380
+ auto sum = __fadd2_rn(sum_0, sum_1);
381
+ float result = scale_kv * (sum.x + sum.y);
382
+
383
+ // Store into the global memory
384
+ // NOTES: we have redundant writes here, consider more carefully
385
+ logits[kv_offset + i * logits_stride + thread_idx] = result;
386
+ }
387
+ }
388
+ } else {
389
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
390
+ }
391
+
392
+ // Free tensor memory
393
+ __syncthreads();
394
+ if (is_umma_warp)
395
+ cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
396
+ }
397
+
398
+ } // namespace deep_gemm
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #pragma clang diagnostic push
3
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
4
+
5
+ #include <cutlass/arch/barrier.h>
6
+
7
+ #include <deep_gemm/common/reduction.cuh>
8
+ #include <deep_gemm/common/utils.cuh>
9
+ #include <deep_gemm/common/sm90_utils.cuh>
10
+ #include <deep_gemm/common/sm100_utils.cuh>
11
+
12
+ namespace deep_gemm {
13
+
14
+ using namespace deep_gemm::sm100;
15
+
16
+ template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
17
+ __device__ __forceinline__
18
+ uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
19
+ // Calculate the index of the bank group to be written in the atom
20
+ const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
21
+
22
+ // Reshape the atom in another view and swizzle
23
+ // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
24
+ // - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)`
25
+ constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
26
+ constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups;
27
+ auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups);
28
+ auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups);
29
+ col ^= row % (kSwizzleMode / kSwizzleBase);
30
+
31
+ return row * 128 + col * kSwizzleBase;
32
+ }
33
+
34
+ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
35
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
36
+ uint32_t kNumSplits,
37
+ uint32_t kSwizzleCDMode,
38
+ uint32_t kNumStages,
39
+ uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
40
+ __global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
41
+ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
42
+ const __grid_constant__ cute::TmaDescriptor tensor_map_a,
43
+ const __grid_constant__ cute::TmaDescriptor tensor_map_b,
44
+ const __grid_constant__ cute::TmaDescriptor tensor_map_d,
45
+ float* sqr_sum) {
46
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
47
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
48
+
49
+ // Configs
50
+ constexpr uint32_t kNumCastStages = 2;
51
+ constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
52
+ constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
53
+ constexpr auto kMajorA = cute::UMMA::Major::K;
54
+ constexpr auto kMajorB = cute::UMMA::Major::K;
55
+ DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages");
56
+ DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
57
+ DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads");
58
+
59
+ // Utils
60
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
61
+ const auto lane_idx = get_lane_idx();
62
+
63
+ // Align to 1024 bytes for swizzle-128B
64
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
65
+
66
+ // Share memory sizes
67
+ constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
68
+ constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
69
+ constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
70
+ DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
71
+
72
+ // Real tensor memory size and offsets
73
+ constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
74
+
75
+ // Prefetch TMA descriptors at the very beginning
76
+ if (warp_idx == 0 and cute::elect_one_sync()) {
77
+ cute::prefetch_tma_descriptor(&tensor_map_a);
78
+ cute::prefetch_tma_descriptor(&tensor_map_b);
79
+ cute::prefetch_tma_descriptor(&tensor_map_d);
80
+ }
81
+
82
+ // Data on shared memory (layout as ordered below)
83
+ // Fill D/A/B pointers
84
+ auto smem_cd = reinterpret_cast<float*>(smem_buffer);
85
+ auto smem_a = PatternVisitor([&](const uint32_t& i) {
86
+ return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
87
+ });
88
+ auto smem_b = PatternVisitor([&](const uint32_t& i) {
89
+ return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
90
+ });
91
+
92
+ // Fill barriers
93
+ auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
94
+ kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
95
+ auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
96
+ auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
97
+ auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
98
+ auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
99
+ auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
100
+
101
+ // Fill the tensor memory pointer
102
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 4 + 1);
103
+ DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
104
+
105
+ // Initialize barriers
106
+ if (warp_idx == 1 and cute::elect_one_sync()) {
107
+ #pragma unroll
108
+ for (uint32_t i = 0; i < kNumStages; ++ i) {
109
+ full_barriers[i]->init(1);
110
+ full_cast_barriers[i]->init(kNumCastAndReduceThreads);
111
+ empty_barriers[i]->init(1);
112
+ empty_cast_barriers[i]->init(1);
113
+ }
114
+ tmem_full_barrier->init(1);
115
+
116
+ // Make initialized barrier visible in async proxy
117
+ cutlass::arch::fence_barrier_init();
118
+ } else if (warp_idx == 2) {
119
+ // Allocate tensor memory
120
+ cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
121
+ }
122
+ __syncthreads();
123
+
124
+ constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
125
+ constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
126
+ constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
127
+ const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
128
+ const uint32_t m_block_idx = block_idx / kNumSplits;
129
+ const uint32_t k_split_idx = block_idx % kNumSplits;
130
+ const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
131
+ const uint32_t m_offset = shape_m * k_split_idx;
132
+ const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
133
+
134
+ // Dispatch warps into different roles
135
+ if (warp_idx < kNumMMAThreads / 32) {
136
+ // TMA load warp
137
+ if (warp_idx == 0 and cute::elect_one_sync()) {
138
+ for (uint32_t s = 0; s < num_total_stages; ++ s) {
139
+ // Wait consumer release
140
+ const auto& stage_idx = s % kNumStages;
141
+ empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
142
+
143
+ // Compute offsets
144
+ uint32_t m_idx = m_block_idx * BLOCK_M;
145
+ uint32_t k_idx = k_offset + s * BLOCK_K;
146
+
147
+ // Issue TMAs
148
+ tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
149
+ tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
150
+
151
+ // Arrive at full barriers
152
+ constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
153
+ full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
154
+ }
155
+ }
156
+
157
+ // MMA issue warp
158
+ if (warp_idx == 1) {
159
+ // Make instruction descriptor
160
+ constexpr uint32_t UMMA_M = BLOCK_M;
161
+ constexpr uint32_t UMMA_N = BLOCK_N;
162
+ constexpr uint32_t UMMA_K = 32 / sizeof(float);
163
+ constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float);
164
+ using umma_t = cute::SM100_MMA_TF32_TS<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
165
+ BLOCK_M, BLOCK_N, kMajorA, kMajorB>;
166
+ auto instr_desc = cute::UMMA::make_instr_desc<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
167
+ UMMA_M, UMMA_N, kMajorA, kMajorB>();
168
+ const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
169
+
170
+ DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
171
+ auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
172
+ const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
173
+
174
+ // Checks for MMA instructions
175
+ // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
176
+ DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
177
+ (UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
178
+ (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
179
+ "Invalid MMA instruction shape");
180
+
181
+ // Launch MMAs
182
+ // We can not unroll this part
183
+ for (uint32_t s = 0; s < num_total_stages; ++ s) {
184
+ // Wait TMA arrival
185
+ const auto& stage_idx = s % kNumStages;
186
+ const auto& cast_stage_idx = s % kNumCastStages;
187
+ full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
188
+ tcgen05_after_thread_sync();
189
+
190
+ // Issue UMMA
191
+ const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
192
+ #pragma unroll
193
+ for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
194
+ const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
195
+ const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
196
+ const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
197
+ b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
198
+ umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
199
+ }
200
+
201
+ // Commit
202
+ cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_cast_barriers[cast_stage_idx]));
203
+ cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
204
+ }
205
+
206
+ // Commit to epilogue threads
207
+ cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
208
+ }
209
+
210
+ // TMA checks
211
+ constexpr uint32_t kNumBankGroupBytes = 16;
212
+ constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
213
+ DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
214
+ DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
215
+
216
+ // Only support layout F (M = 64) and D (M = 128)
217
+ DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M");
218
+
219
+ // Wait UMMA arrival
220
+ tmem_full_barrier->wait(0);
221
+ tcgen05_after_thread_sync();
222
+
223
+ // Load from tensor memory into registers, and write shared memory with STSM
224
+ DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
225
+
226
+ // Store into shared memory
227
+ #pragma unroll
228
+ for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) {
229
+ // Source and destination memory address
230
+ uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup;
231
+ auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) + // Base pointer
232
+ warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset
233
+ get_swizzled_smem_offset<kSwizzleCDMode>(i, lane_idx); // In-atom offset
234
+
235
+ // Load from tensor memory, store into shared memory
236
+ uint32_t values[kNumElemsPerBankGroup];
237
+ DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
238
+ cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
239
+ values[0], values[1], values[2], values[3]);
240
+ cutlass::arch::fence_view_async_tmem_load();
241
+ if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
242
+ st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
243
+ if constexpr (BLOCK_M == 64)
244
+ __syncwarp();
245
+ }
246
+
247
+ // Synchronize all threads and issue TMA
248
+ cute::tma_store_fence();
249
+ cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0);
250
+ if (warp_idx == 0 and cute::elect_one_sync()) {
251
+ if constexpr (kNumSplits == 1) {
252
+ cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
253
+ } else {
254
+ cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
255
+ }
256
+ cute::tma_store_arrive();
257
+ }
258
+
259
+ // Deallocate tensor memory by warp 1
260
+ // NOTES: warp 0 is waiting TMA store
261
+ if (warp_idx == 1)
262
+ cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
263
+ } else {
264
+ DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
265
+ DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads");
266
+ constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
267
+ const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32;
268
+
269
+ // TODO: make even larger block K
270
+ DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
271
+
272
+ // Launch reductions
273
+ float2 sum[2] = {float2{0, 0}, float2{0, 0}};
274
+ #pragma unroll kNumStages
275
+ for (uint32_t s = 0; s < num_total_stages; ++ s) {
276
+ // Wait TMA arrival
277
+ const auto& stage_idx = s % kNumStages;
278
+ full_barriers[stage_idx]->wait((s / kNumStages) & 1);
279
+
280
+ // Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b)
281
+ constexpr uint32_t kNumBankGroupBytes = 16;
282
+ constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
283
+ constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
284
+ const auto& smem_base_ptr = reinterpret_cast<uint8_t*>(smem_a[stage_idx]) + // Base pointer
285
+ sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset
286
+
287
+ // 4 lanes shared a bank group
288
+ uint32_t uint32_values[2][kNumLoads];
289
+ DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads");
290
+ #pragma unroll
291
+ for (uint32_t i = 0; i < kNumLoads; i += 2) {
292
+ auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
293
+ sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
294
+ uint32_values[0][i + 1], uint32_values[1][i + 1],
295
+ smem_ptr);
296
+ }
297
+
298
+ // Wait tensor memory empty
299
+ const auto& cast_stage_idx = s % kNumCastStages;
300
+ empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1);
301
+
302
+ // Cast, reduce and store into tensor memory
303
+ float2 fp32x2_values[2][kNumLoads];
304
+ const auto& upper_view = reinterpret_cast<uint32_t*>(&fp32x2_values[0]);
305
+ const auto& lower_view = reinterpret_cast<uint32_t*>(&fp32x2_values[1]);
306
+ #pragma unroll
307
+ for (uint32_t i = 0; i < kNumLoads; ++ i) {
308
+ #pragma unroll
309
+ for (uint32_t u = 0; u < 2; ++ u) {
310
+ fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast<nv_bfloat162*>(&uint32_values[u][i]));
311
+ sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]);
312
+ }
313
+
314
+ // Store upper and lower part at the same time
315
+ const auto idx_0 = i * 2, idx_1 = i * 2 + 1;
316
+ cute::SM100_TMEM_STORE_16dp256b1x::copy(
317
+ upper_view[idx_0], upper_view[idx_1],
318
+ lower_view[idx_0], lower_view[idx_1],
319
+ cast_stage_idx * BLOCK_K + i * 8);
320
+ }
321
+ cutlass::arch::fence_view_async_tmem_store();
322
+
323
+ // Arrive for issuing MMAs
324
+ tcgen05_before_thread_sync();
325
+ full_cast_barriers[cast_stage_idx]->arrive();
326
+ }
327
+
328
+ // Intra-warp reduction and write back
329
+ #pragma unroll
330
+ for (uint32_t u = 0; u < 2; ++ u) {
331
+ const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y);
332
+ const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
333
+ if (lane_idx % 4 == 0 and m_idx < shape_m)
334
+ sqr_sum[m_offset + m_idx] = reduced_sum;
335
+ }
336
+ }
337
+ #else
338
+ if (blockIdx.x == 0 and threadIdx.x == 0)
339
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
340
+ #endif
341
+ }
342
+
343
+ } // namespace deep_gemm
344
+
345
+ #pragma clang diagnostic pop
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #pragma clang diagnostic push
4
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
5
+
6
+ #include <cutlass/arch/barrier.h>
7
+ #include <cutlass/arch/reg_reconfig.h>
8
+
9
+ #include <cute/arch/cluster_sm90.hpp>
10
+ #include <cute/arch/copy_sm90_desc.hpp>
11
+ #include <cute/arch/copy_sm90_tma.hpp>
12
+ #include <cute/arch/mma_sm100_desc.hpp>
13
+
14
+ #include <deep_gemm/common/utils.cuh>
15
+ #include <deep_gemm/common/scheduler.cuh>
16
+ #include <deep_gemm/common/sm90_utils.cuh>
17
+
18
+ namespace deep_gemm {
19
+
20
+ using namespace deep_gemm::sm90;
21
+
22
+ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
23
+ uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
24
+ uint32_t kNumGroups,
25
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
26
+ uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
27
+ uint32_t kNumStages_,
28
+ uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
29
+ uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
30
+ uint32_t kNumSMs,
31
+ GemmType kGemmType, bool kWithAccumulation,
32
+ typename cd_dtype_t>
33
+ __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
34
+ sm90_bf16_gemm_impl(int* grouped_layout,
35
+ uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
36
+ const __grid_constant__ cute::TmaDescriptor tensor_map_a,
37
+ const __grid_constant__ cute::TmaDescriptor tensor_map_b,
38
+ const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
39
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
40
+ // Enlarge `BLOCK_K` for some cases
41
+ // NOTES: this is for reducing the `warpgroup_wait<0>()` overhead
42
+ constexpr uint32_t kDoMergeStages =
43
+ kNumStages_ >= 10 and
44
+ kGemmType == GemmType::Normal and
45
+ kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K and
46
+ kNumMathThreads == 128;
47
+ // Ensure there are at least `kNumMinStages` stages after merge
48
+ constexpr uint32_t kNumMinStages = 5;
49
+ constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1;
50
+ constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge;
51
+ constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
52
+
53
+ // Types
54
+ using WGMMA = typename BF16MMASelector<BLOCK_N, kMajorA, kMajorB>::type;
55
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
56
+ DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
57
+
58
+ // Overwrite shape constants if the compiler gives
59
+ shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
60
+ shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
61
+ shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
62
+
63
+ // Shared memory
64
+ static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
65
+ static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
66
+ static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
67
+
68
+ // NOTES: Make sure we have enough shared memory for WGMMA padding
69
+ static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
70
+ DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
71
+
72
+ // Configs
73
+ const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
74
+ const uint32_t lane_idx = get_lane_idx();
75
+
76
+ // Prefetch TMA descriptors at the very beginning
77
+ if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
78
+ cute::prefetch_tma_descriptor(&tensor_map_a);
79
+ cute::prefetch_tma_descriptor(&tensor_map_b);
80
+ cute::prefetch_tma_descriptor(&tensor_map_cd);
81
+ }
82
+ __syncwarp();
83
+
84
+ // Align to 1024 bytes for swizzle-128B
85
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
86
+ DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
87
+ "Shared memory of A/B/D must be aligned to 1024 bytes");
88
+
89
+ // D/A/B shared memory
90
+ auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
91
+ auto smem_a = PatternVisitor([&](const uint32_t& i) {
92
+ return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
93
+ });
94
+ auto smem_b = PatternVisitor([&](const uint32_t& i) {
95
+ return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
96
+ });
97
+
98
+ // Fill barriers
99
+ auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
100
+ auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
101
+ auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
102
+
103
+ // Initialize barriers
104
+ if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
105
+ #pragma unroll
106
+ for (uint32_t i = 0; i < kNumStages; ++ i) {
107
+ full_barriers[i]->init(1);
108
+ empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
109
+ }
110
+
111
+ // Make initialized barrier visible in async proxy
112
+ cutlass::arch::fence_barrier_init();
113
+ }
114
+
115
+ // Synchronize all threads to make barrier visible in normal memory model
116
+ (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
117
+
118
+ // Register reconfigurations
119
+ constexpr uint32_t kNumTMARegisters = 48;
120
+ constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224;
121
+
122
+ // Block scheduler
123
+ uint32_t m_block_idx, n_block_idx;
124
+ auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
125
+
126
+ // Pipeline and TMA phases
127
+ uint32_t stage_idx = 0, phase = 0;
128
+ auto advance_pipeline = [&](uint32_t& k_block_idx) {
129
+ ++ k_block_idx;
130
+
131
+ // Flip phases only if reach the next first stage
132
+ stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
133
+ phase ^= stage_idx == 0;
134
+ };
135
+
136
+ if (warp_idx >= kNumMathThreads / 32) {
137
+ // TMA warp-group for loading data
138
+ cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
139
+
140
+ // NOTES: only one thread (or warp) will be used
141
+ // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32`
142
+ if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) {
143
+ DG_STATIC_ASSERT(kNumTMAThreads >= 128, "Need at least 128 threads for TMA warp-group");
144
+
145
+ // Persistently schedule over blocks
146
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
147
+ // Assign TMA multicast number into A and B
148
+ // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
149
+ const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
150
+ const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
151
+ const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
152
+ DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
153
+
154
+ const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
155
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
156
+ // Wait consumer release
157
+ empty_barriers[stage_idx]->wait(phase ^ 1);
158
+
159
+ constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
160
+ auto& full_barrier = *full_barriers[stage_idx];
161
+
162
+ const auto m_idx = scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
163
+ const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
164
+
165
+ DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
166
+ uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
167
+ shape_k, BLOCK_K, k_block_idx, m_block_idx);
168
+ uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
169
+ shape_k, BLOCK_K, k_block_idx, m_block_idx);
170
+
171
+ // Issue TMAs
172
+ constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
173
+ const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
174
+ if constexpr (kMajorA == cute::UMMA::Major::K)
175
+ tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
176
+ &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx);
177
+ if constexpr (kMajorA == cute::UMMA::Major::MN)
178
+ tma_copy<BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
179
+ &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx);
180
+ if constexpr (kMajorB == cute::UMMA::Major::K)
181
+ tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
182
+ &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx);
183
+ if constexpr (kMajorB == cute::UMMA::Major::MN)
184
+ tma_copy<BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
185
+ &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx);
186
+
187
+ full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
188
+ }
189
+ }
190
+
191
+ // To safely deconstruct distributed shared barriers, we need another round of empty waits
192
+ if constexpr (kNumTMAMulticast > 1) {
193
+ for (uint32_t i = 0; i < kNumStages; advance_pipeline(i))
194
+ empty_barriers[stage_idx]->wait(phase ^ 1);
195
+ }
196
+ }
197
+ } else {
198
+ // Math warp-groups for WGMMA
199
+ cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
200
+
201
+ // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
202
+ const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
203
+
204
+ // Merged stages only happens in NT normal GEMM cases
205
+ constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
206
+ auto a_desc = make_gmma_desc<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], math_wg_idx * WGMMA::M, 0);
207
+ auto b_desc = make_gmma_desc<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
208
+ const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
209
+ const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
210
+
211
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
212
+ constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2;
213
+ DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
214
+ float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
215
+
216
+ // Pick threads whose WGMMA results are to be stored in shared memory
217
+ DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`");
218
+ constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M);
219
+ const bool do_wgmma_store = BLOCK_M >= 64 or warp_idx < kNumWGMMAStoreThreads / 32;
220
+
221
+ // Empty barrier arrival
222
+ auto empty_barrier_arrive = [&](uint32_t s) {
223
+ if constexpr (kNumTMAMulticast == 1) {
224
+ lane_idx == 0 ? empty_barriers[s]->arrive() : void();
225
+ } else {
226
+ auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
227
+ lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
228
+ }
229
+ };
230
+
231
+ // TODO: remove some useless computation for unaligned Ms
232
+ const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
233
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
234
+ const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
235
+ const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
236
+
237
+ // Wait TMA arrivals
238
+ full_barriers[stage_idx]->wait(phase);
239
+
240
+ // Commit WGMMA instructions
241
+ #pragma unroll
242
+ for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
243
+ warpgroup_fence_operand(accum[i]);
244
+ warpgroup_arrive();
245
+ #pragma unroll
246
+ for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
247
+ auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
248
+ #pragma unroll
249
+ for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
250
+ const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
251
+ a_desc.reg32_[0] = advance_gmma_desc_lo<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode, nv_bfloat16>(
252
+ a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K);
253
+ b_desc.reg32_[0] = advance_gmma_desc_lo<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode, nv_bfloat16>(
254
+ b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K);
255
+ WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1);
256
+ }
257
+ }
258
+ warpgroup_commit_batch();
259
+ #pragma unroll
260
+ for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
261
+ warpgroup_fence_operand(accum[i]);
262
+ warpgroup_wait<0>();
263
+
264
+ // Notify barrier arrival
265
+ empty_barrier_arrive(stage_idx);
266
+ }
267
+
268
+ // TMA checks
269
+ constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
270
+ constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
271
+ constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
272
+ DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
273
+ DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
274
+ "Unaligned TMA store or too many TMA store instructions");
275
+ DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
276
+
277
+ // Skip WGMMA store for the unfilled parts
278
+ if (not do_wgmma_store)
279
+ continue;
280
+
281
+ // Wait last TMA store to be finished
282
+ if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
283
+ cute::tma_store_wait<0>();
284
+ cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
285
+
286
+ if constexpr (cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>) {
287
+ // Write back to shared memory using STSM and issue TMA stores
288
+ DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
289
+ DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
290
+ #pragma unroll
291
+ for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
292
+ auto m_offset = local_idx * WAVE_BLOCK_M;
293
+ auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
294
+ #pragma unroll
295
+ for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
296
+ // Swizzle or padding into the correct address
297
+ uint8_t* smem_ptr = nullptr;
298
+ if constexpr (kSwizzleDMode > 0) {
299
+ // Calculate the swizzling atom offset and in-atom offset
300
+ constexpr uint32_t kNumBankGroupBytes = 16;
301
+ auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
302
+
303
+ // Calculate the index of the bank group to be written in the atom
304
+ auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
305
+
306
+ // Reshape the atom in another view and swizzle
307
+ // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
308
+ // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
309
+ constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
310
+ auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
311
+ auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
312
+ col ^= row % (kSwizzleDMode / 16);
313
+
314
+ // Add back into the base pointer
315
+ // NOTES: think twice before modifying this, as changes may affect the number of instructions
316
+ smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
317
+ warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
318
+ m_offset * kSwizzleDMode + // Wave offset
319
+ atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
320
+ row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
321
+ } else {
322
+ // No swizzling
323
+ smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
324
+ }
325
+
326
+ // NOTES: only 16 lanes' addresses are used
327
+ SM90_U32x2_STSM_N<nv_bfloat162>::copy(
328
+ __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
329
+ __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
330
+ smem_ptr
331
+ );
332
+ }
333
+ }
334
+ } else {
335
+ // Use `st.shared` if STSM is not available
336
+ #pragma unroll
337
+ for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
338
+ auto m_offset = local_idx * WAVE_BLOCK_M;
339
+ auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
340
+ auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2);
341
+ auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
342
+ #pragma unroll
343
+ for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
344
+ st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
345
+ st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
346
+ }
347
+ }
348
+ }
349
+ cute::tma_store_fence();
350
+ cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
351
+
352
+ // Use TMA store to write back to global memory
353
+ const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
354
+ DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
355
+ if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
356
+ auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
357
+ auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
358
+ if constexpr (kGemmType == GemmType::Batched) {
359
+ cute::SM90_TMA_STORE_3D::copy(&tensor_map_cd, smem_ptr,
360
+ n_block_idx * BLOCK_N + in_block_n_offset,
361
+ m_idx, scheduler.current_group_idx);
362
+ } else {
363
+ using cute_tma_t = cute::conditional_t<kWithAccumulation,
364
+ cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
365
+ cute_tma_t::copy(&tensor_map_cd, smem_ptr,
366
+ n_block_idx * BLOCK_N + in_block_n_offset, m_idx);
367
+ }
368
+ cute::tma_store_arrive();
369
+ }
370
+ __syncwarp();
371
+ }
372
+ }
373
+ #else
374
+ if (blockIdx.x == 0 and threadIdx.x == 0)
375
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
376
+ #endif
377
+ }
378
+
379
+ }; // namespace deep_gemm
380
+
381
+ #pragma clang diagnostic pop
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/arch/cluster_sm90.hpp>
4
+ #include <cutlass/arch/barrier.h>
5
+ #include <cutlass/arch/reg_reconfig.h>
6
+
7
+ #include <deep_gemm/common/utils.cuh>
8
+ #include <deep_gemm/common/sm90_utils.cuh>
9
+
10
namespace deep_gemm {

using namespace deep_gemm::sm90;

// SM90 (Hopper) batched BF16 GEMM with split-K accumulation into an fp32 output.
//
// Layout (as established by the TMA copies and the epilogue below):
//   - A is addressed at column `k_idx` and row `m_block_idx * BLOCK_M + s_idx * SHAPE_M`,
//     i.e. the batch dimension `s` is folded into the row axis of tensor_map_a
//     (one SHAPE_M-row slab per batch). B is addressed the same way with SHAPE_N.
//   - D is a plain MxN fp32 buffer; partial results from every (s, k) slice are
//     combined with `atomicAdd`, so callers must zero-initialize `d`.
//     NOTE(review): the function name says "bmn_bnk" while the file is named
//     "bmk_bnk" — presumably both operands are K-major slabs; confirm against host code.
//
// Work decomposition:
//   - blockIdx.x enumerates (mn_block, sk_block) pairs. Each CTA owns one
//     BLOCK_M x BLOCK_N output tile and a contiguous run of `kSplitFactor`
//     BLOCK_K-wide (s, k) slices (split-K), clamped at the end of the s*K range.
//   - One 128-thread TMA warp-group produces stages; two math warp-groups
//     (256 threads) consume them with WGMMA.
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kSplitFactor,
          uint32_t kNumStages,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_b,
                          float *d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
    // Types
    using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;
    DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");

    // Shared memory: double/multi-buffered A and B tiles, one pair per pipeline stage
    static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
    static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);

    // Configs
    const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const uint32_t lane_idx = get_lane_idx();
    DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
    DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
    DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");

    // Prefetch TMA descriptors at the very beginning
    if (warp_idx == 0 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_a);
        cute::prefetch_tma_descriptor(&tensor_map_b);
    }
    __syncwarp();

    // Align to 1024 bytes for swizzle-128B
    // Fill shared memory pointers
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    auto smem_a = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
    });
    auto smem_b = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
    });

    // Fill barriers: `kNumStages` full barriers followed by `kNumStages` empty barriers,
    // placed directly after the A/B staging buffers
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
    auto full_barriers  = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
    auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });

    // Initialize barriers
    // `full` is armed by the single TMA-issuing thread (arrival count 1);
    // `empty` is armed by every math thread (arrival count kNumMathThreads)
    if (warp_idx == 1 and cute::elect_one_sync()) {
        #pragma unroll
        for (uint32_t i = 0; i < kNumStages; ++ i) {
            full_barriers[i]->init(1);
            empty_barriers[i]->init(kNumMathThreads);
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    }

    // Synchronize all threads to make barrier visible in normal memory model
    __syncthreads();

    // Register reconfigurations: shrink the TMA warp-group's register file and
    // grow the math warp-groups'
    constexpr uint32_t kNumTMARegisters = 40;
    constexpr uint32_t kNumMathRegisters = 232;

    // Block indices: decode (m, n) tile and split-K slice from blockIdx.x
    const uint32_t num_n_blocks  = ceil_div(SHAPE_N, BLOCK_N);
    const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
    const uint32_t mn_block_idx  = blockIdx.x % num_mn_blocks;
    const uint32_t sk_block_idx  = blockIdx.x / num_mn_blocks;
    const uint32_t n_block_idx   = mn_block_idx % num_n_blocks;
    const uint32_t m_block_idx   = mn_block_idx / num_n_blocks;
    // Number of (s, k) slices this CTA consumes: `kSplitFactor`, clamped at the tail
    const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);

    if (warp_idx >= kNumMathThreads / 32) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();

        // NOTES: only one thread (or warp) will be used
        if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
            // Persistently schedule over blocks
            #pragma unroll
            for (uint32_t s = 0; s < num_total_stages; ++ s) {
                // Wait consumer release
                // Producer phase starts at 1 (buffers begin empty), hence the `+ 1`
                const auto& stage_idx = s % kNumStages;
                empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);

                auto& full_barrier = *full_barriers[stage_idx];
                // Linear (s, k) index over the whole split-K range, in elements of K
                const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
                const uint32_t& k_idx  = sk_idx % SHAPE_K;   // K offset within the batch
                const uint32_t& s_idx  = sk_idx / SHAPE_K;   // batch index

                // 128B swizzle for a BLOCK_K-wide bf16 row
                constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16);
                // Batch `s_idx` is folded into the row coordinate (one SHAPE_M/SHAPE_N slab per batch)
                tma_copy<BLOCK_K, BLOCK_M, kSwizzle>(
                    &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
                tma_copy<BLOCK_K, BLOCK_N, kSwizzle>(
                    &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
                full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
            }
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
        const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
        // Accumulate all `num_total_stages` slices in registers; written out once at the end
        float accum[WGMMA::kNumAccum] = {0};

        // Launch MMAs
        for (uint32_t s = 0; s < num_total_stages; ++ s) {
            // Wait TMA arrivals (consumer phase starts at 0)
            const auto& stage_idx = s % kNumStages;
            full_barriers[stage_idx]->wait((s / kNumStages) & 1);

            // Commit WGMMA instructions
            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);
            warpgroup_arrive();
            #pragma unroll
            for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                // Each 128-thread warp-group handles its own WGMMA::M-row half of the tile
                auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
                auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
                WGMMA::wgmma(desc_a, desc_b, accum, 1);
            }
            warpgroup_commit_batch();
            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);
            // Block until this slice's WGMMAs retire before releasing the stage
            warpgroup_wait<0>();

            // Notify barrier arrival at the last warpgroup wave
            empty_barriers[stage_idx]->arrive();
        }

        // Epilogue: each lane owns two 8-row-apart rows of the WGMMA fragment
        // (rows r and r+8), four float2 pairs per accumulator group of 4
        const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
        const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
        #pragma unroll
        for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
            if (col + i * 8 >= SHAPE_N)
                break;
            if (row < SHAPE_M) {
                // Split-K partial sums from all CTAs are merged via fp32 atomics
                atomicAdd(reinterpret_cast<float2*>(d + (row + 0) * SHAPE_N + col + i * 8),
                          make_float2(accum[i * 4 + 0], accum[i * 4 + 1]));
            }
            if (row + 8 < SHAPE_M) {
                atomicAdd(reinterpret_cast<float2*>(d + (row + 8) * SHAPE_N + col + i * 8),
                          make_float2(accum[i * 4 + 2], accum[i * 4 + 3]));
            }
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}

}; // namespace deep_gemm
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #pragma clang diagnostic push
4
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
5
+
6
+ #include <cutlass/arch/barrier.h>
7
+ #include <cutlass/arch/reg_reconfig.h>
8
+
9
+ #include <cute/arch/cluster_sm90.hpp>
10
+ #include <cute/arch/copy_sm90_desc.hpp>
11
+ #include <cute/arch/copy_sm90_tma.hpp>
12
+
13
+ #include <deep_gemm/common/utils.cuh>
14
+ #include <deep_gemm/common/scheduler.cuh>
15
+ #include <deep_gemm/common/sm90_utils.cuh>
16
+
17
+ namespace deep_gemm {
18
+
19
+ using namespace deep_gemm::sm90;
20
+
21
+ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
22
+ uint32_t kNumGroups,
23
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
24
+ uint32_t kSwizzleAMode, uint32_t kSwizzleBMode,
25
+ uint32_t kNumStages,
26
+ uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
27
+ uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
28
+ uint32_t kNumSMs,
29
+ GemmType kGemmType, typename cd_dtype_t>
30
+ __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
31
+ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
32
+ int* grouped_layout,
33
+ cute::TmaDescriptor* tensor_map_buffer,
34
+ uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
35
+ const __grid_constant__ cute::TmaDescriptor tensor_map_a_base,
36
+ const __grid_constant__ cute::TmaDescriptor tensor_map_b_base,
37
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
38
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
39
+ const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
40
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
41
+ // Scaling checks
42
+ DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads");
43
+ DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
44
+ DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
45
+ DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
46
+
47
+ // Types
48
+ using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
49
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
50
+ DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
51
+
52
+ // Overwrite shape constants if the compiler gives
53
+ shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
54
+ shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
55
+ shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
56
+
57
+ // Shared memory
58
+ static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0);
59
+ static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
60
+ static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
61
+ static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
62
+ static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
63
+ static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
64
+ static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
65
+ DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
66
+
67
+ // Configs
68
+ const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
69
+ const uint32_t lane_idx = threadIdx.x % 32;
70
+
71
+ // Prefetch TMA descriptors at the very beginning
72
+ if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
73
+ cute::prefetch_tma_descriptor(&tensor_map_a_base);
74
+ cute::prefetch_tma_descriptor(&tensor_map_b_base);
75
+ cute::prefetch_tma_descriptor(&tensor_map_sfa);
76
+ cute::prefetch_tma_descriptor(&tensor_map_sfb);
77
+ cute::prefetch_tma_descriptor(&tensor_map_cd);
78
+ }
79
+ __syncwarp();
80
+
81
+ // Align to 1024 bytes for swizzle-128B
82
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
83
+ DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
84
+
85
+ // Tensor maps on shared and global memory
86
+ auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) {
87
+ return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * i);
88
+ });
89
+ auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) {
90
+ return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * (2 + i));
91
+ });
92
+ auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; });
93
+ auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });
94
+
95
+ // Data on shared memory
96
+ auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
97
+ auto smem_a = PatternVisitor([&](const uint32_t& i) {
98
+ return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
99
+ });
100
+ auto smem_b = PatternVisitor([&](const uint32_t& i) {
101
+ return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
102
+ });
103
+ constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
104
+ auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
105
+ return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
106
+ });
107
+ auto smem_sfb = PatternVisitor([&](const uint32_t& i) {
108
+ return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
109
+ });
110
+
111
+ // Barriers on shared memory
112
+ constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
113
+ auto full_barriers = PatternVisitor([&](const uint32_t& i) {
114
+ return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
115
+ });
116
+ auto empty_barriers = PatternVisitor([&](const uint32_t& i) {
117
+ return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
118
+ });
119
+
120
+ if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
121
+ // Load tensormap A/B to shared memory
122
+ if constexpr (kGemmType == GemmType::KGroupedContiguous) {
123
+ *smem_tensor_map_a[0] = tensor_map_a_base;
124
+ *smem_tensor_map_a[1] = tensor_map_a_base;
125
+ *smem_tensor_map_b[0] = tensor_map_b_base;
126
+ *smem_tensor_map_b[1] = tensor_map_b_base;
127
+ }
128
+
129
+ // Initialize barriers
130
+ // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
131
+ // even with TMA multicast disabled, we want to make the behavior aligned
132
+ #pragma unroll
133
+ for (uint32_t i = 0; i < kNumStages; ++ i) {
134
+ full_barriers[i]->init(1);
135
+ empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
136
+ }
137
+
138
+ // Make initialized barrier visible in async proxy
139
+ cutlass::arch::fence_barrier_init();
140
+ }
141
+
142
+ // Synchronize all threads to make barrier visible in normal memory model
143
+ (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
144
+
145
+ // Pipeline unroll control
146
+ constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages);
147
+
148
+ // Register reconfigurations (more math registers are needed with unrolling)
149
+ constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
150
+ constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
151
+
152
+ // Block scheduler
153
+ uint32_t m_block_idx, n_block_idx;
154
+ auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
155
+
156
+ // TMA and MMA pipeline
157
+ const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
158
+ return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
159
+ };
160
+ uint32_t iter_idx = 0;
161
+
162
+ if (warp_idx >= kNumMathThreads / 32) {
163
+ // TMA warp-group for loading data
164
+ cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
165
+
166
+ // NOTES: only one thread (or warp) will be used
167
+ if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
168
+ const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
169
+ const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
170
+ uint32_t last_group_idx = kNumGroups, sum_k = 0;
171
+
172
+ // Persistently schedule over blocks
173
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
174
+ // Assign TMA multicast number into A and B
175
+ // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
176
+ const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
177
+ const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
178
+ const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
179
+ DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
180
+
181
+ const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
182
+ const uint32_t& m_idx = m_block_idx * BLOCK_M;
183
+ const uint32_t& n_idx = n_block_idx * BLOCK_N;
184
+
185
+ if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) {
186
+ const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1;
187
+ const uint32_t& next_stage_idx = stage_idx ^ 1;
188
+ last_group_idx = scheduler.current_group_idx;
189
+
190
+ // Prepare next tensor map
191
+ sum_k += scheduler.current_shape_k;
192
+ if (scheduler.next_group_idx < kNumGroups) {
193
+ tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast<uint64_t>(sum_k) * shape_m);
194
+ tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast<uint64_t>(sum_k) * shape_n);
195
+ tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
196
+ tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
197
+ *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
198
+ *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
199
+ tensor_map_release_cta();
200
+ }
201
+
202
+ // Get current tensor map
203
+ if (scheduler.current_num_valid_groups > 0) {
204
+ tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]);
205
+ tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]);
206
+ current_tensor_map_a = gmem_tensor_map_a[stage_idx];
207
+ current_tensor_map_b = gmem_tensor_map_b[stage_idx];
208
+ }
209
+ }
210
+
211
+ #pragma unroll kNumPipelineUnrolls
212
+ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
213
+ // Wait consumer release
214
+ CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
215
+ empty_barriers[stage_idx]->wait(phase ^ 1);
216
+
217
+ // Issue TMA
218
+ auto& full_barrier = *full_barriers[stage_idx];
219
+ const uint32_t& k_idx = k_block_idx * BLOCK_K;
220
+ const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
221
+ tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
222
+ tma_copy<BLOCK_N, BLOCK_K, 0>(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
223
+ tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
224
+ tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
225
+ full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
226
+ }
227
+ }
228
+
229
+ // To safely deconstruct distributed shared barriers, we need another round of empty waits
230
+ if constexpr (kNumTMAMulticast > 1) {
231
+ #pragma unroll
232
+ for (uint32_t s = 0; s < kNumStages; ++ s) {
233
+ CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
234
+ empty_barriers[stage_idx]->wait(phase ^ 1);
235
+ }
236
+ }
237
+ }
238
+ } else {
239
+ // Math warp-groups for WGMMA
240
+ cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
241
+
242
+ // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
243
+ const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
244
+ const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
245
+ const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
246
+
247
+ // Persistently schedule over blocks
248
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
249
+ // Accumulation for WGMMA or CUDA promotion
250
+ DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
251
+ const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
252
+ const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
253
+ const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K);
254
+ float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
255
+ float2 scales_b[WGMMA::kNumAccum / 4];
256
+
257
+ // Empty barrier arrival
258
+ auto empty_barrier_arrive = [&](uint32_t s) {
259
+ if constexpr (kNumTMAMulticast == 1) {
260
+ lane_idx == 0 ? empty_barriers[s]->arrive() : void();
261
+ } else {
262
+ auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
263
+ lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
264
+ }
265
+ };
266
+
267
+ #pragma unroll kNumPipelineUnrolls
268
+ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
269
+ // Wait TMA arrivals
270
+ CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
271
+ full_barriers[stage_idx]->wait(phase);
272
+
273
+ // Read A scales
274
+ // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
275
+ auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0);
276
+ auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1);
277
+
278
+ // Read B scales
279
+ #pragma unroll
280
+ for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
281
+ scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
282
+
283
+ // Commit WGMMA instructions
284
+ #pragma unroll
285
+ for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
286
+ warpgroup_fence_operand(accum[i]);
287
+ warpgroup_arrive();
288
+ #pragma unroll
289
+ for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
290
+ auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
291
+ auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
292
+ WGMMA::wgmma(desc_a, desc_b, accum, k);
293
+ }
294
+ warpgroup_commit_batch();
295
+ #pragma unroll
296
+ for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
297
+ warpgroup_fence_operand(accum[i]);
298
+ warpgroup_wait<0>();
299
+
300
+ // Notify barrier arrival
301
+ empty_barrier_arrive(stage_idx);
302
+
303
+ // Promote with scales
304
+ #pragma unroll
305
+ for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
306
+ const float &scale_b_0 = scales_b[i].x;
307
+ const float &scale_b_1 = scales_b[i].y;
308
+ final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
309
+ final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
310
+ final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
311
+ final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
312
+ }
313
+ }
314
+
315
+ // Flush previous stores
316
+ if (warp_idx % 4 == 0 and cute::elect_one_sync())
317
+ cute::tma_store_wait<0>();
318
+ cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
319
+
320
+ // Store to D shared memory
321
+ const auto& smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
322
+ const auto& smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
323
+ #pragma unroll
324
+ for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
325
+ st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
326
+ st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
327
+ }
328
+ cute::tma_store_fence();
329
+ cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
330
+
331
+ // Use TMA store to write back to global memory
332
+ if (warp_idx % 4 == 0 and cute::elect_one_sync()) {
333
+ cute::SM90_TMA_REDUCE_ADD_2D::copy(
334
+ &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N,
335
+ current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0);
336
+ cute::tma_store_arrive();
337
+ }
338
+ __syncwarp();
339
+ }
340
+ }
341
+ #else
342
+ if (blockIdx.x == 0 and threadIdx.x == 0)
343
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
344
+ #endif
345
+ }
346
+
347
+ }; // namespace deep_gemm
348
+
349
+ #pragma clang diagnostic pop
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #pragma clang diagnostic push
4
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
5
+
6
+ #include <cutlass/arch/barrier.h>
7
+ #include <cutlass/arch/reg_reconfig.h>
8
+
9
+ #include <cute/arch/cluster_sm90.hpp>
10
+ #include <cute/arch/copy_sm90_desc.hpp>
11
+ #include <cute/arch/copy_sm90_tma.hpp>
12
+
13
+ #include <deep_gemm/common/epilogue_utils.cuh>
14
+ #include <deep_gemm/common/utils.cuh>
15
+ #include <deep_gemm/common/scheduler.cuh>
16
+ #include <deep_gemm/common/sm90_utils.cuh>
17
+
18
+ namespace deep_gemm {
19
+
20
+ using namespace deep_gemm::sm90;
21
+
22
+ template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
23
+ __device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
24
+ if (num_former_iters == kNumFormerIters) {
25
+ func(cute::Int<kNumFormerIters>{});
26
+ return;
27
+ }
28
+
29
+ if constexpr (kNumFormerIters + kGap <= kEnd)
30
+ dispatch_num_former_iters<kNumFormerIters + kGap, kGap, kEnd>(num_former_iters, func);
31
+ }
32
+
33
+ template <cute::UMMA::Major kMajorSFB,
34
+ uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
35
+ uint32_t kNumGroups,
36
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
37
+ uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
38
+ uint32_t kNumStages, uint32_t kNumLastStages,
39
+ uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
40
+ uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
41
+ uint32_t kNumSMs, GemmType kGemmType,
42
+ typename epilogue_type_t>
43
+ __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
44
+ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
45
+ uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
46
+ const __grid_constant__ cute::TmaDescriptor tensor_map_a,
47
+ const __grid_constant__ cute::TmaDescriptor tensor_map_b,
48
+ const __grid_constant__ cute::TmaDescriptor tensor_map_d,
49
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) {
50
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
51
+ // Scaling checks
52
+ DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
53
+ DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
54
+
55
+ // Types
56
+ using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
57
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
58
+ DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
59
+
60
+ // Overwrite shape constants if the compiler gives
61
+ shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
62
+ shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
63
+ shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
64
+
65
+ // Shared memory
66
+ static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
67
+ static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(__nv_bfloat16)), 1024u);
68
+ static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
69
+ static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
70
+ static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
71
+ static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
72
+ const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K);
73
+ const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K);
74
+ const uint32_t& smem_sfb_size = align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
75
+
76
+ // NOTES: Make sure we have enough shared memory for WGMMA padding
77
+ static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
78
+ DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
79
+
80
+ // Configs
81
+ const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K);
82
+ const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
83
+ const uint32_t lane_idx = get_lane_idx();
84
+
85
+ // Prefetch TMA descriptors at the very beginning
86
+ if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
87
+ cute::prefetch_tma_descriptor(&tensor_map_a);
88
+ cute::prefetch_tma_descriptor(&tensor_map_b);
89
+ cute::prefetch_tma_descriptor(&tensor_map_sfa);
90
+ cute::prefetch_tma_descriptor(&tensor_map_d);
91
+ }
92
+ __syncwarp();
93
+
94
+ // Align to 1024 bytes for swizzle-128B
95
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
96
+ DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
97
+
98
+ // Data on shared memory
99
+ auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
100
+ auto smem_a = PatternVisitor([&](const uint32_t& i) {
101
+ return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
102
+ });
103
+ auto smem_b = PatternVisitor([&](const uint32_t& i) {
104
+ return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
105
+ });
106
+ constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
107
+ auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
108
+ return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
109
+ });
110
+ auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
111
+
112
+ // Fill barriers
113
+ auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
114
+ auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
115
+ auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
116
+
117
+ // Initialize barriers
118
+ DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
119
+ if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
120
+ // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
121
+ // even with TMA multicast disabled, we want to make the behavior aligned
122
+ #pragma unroll
123
+ for (uint32_t i = 0; i < kNumStages; ++ i) {
124
+ full_barriers[i]->init(1);
125
+ empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
126
+ }
127
+
128
+ // Make initialized barrier visible in async proxy
129
+ cutlass::arch::fence_barrier_init();
130
+ }
131
+
132
+ // Synchronize all threads to make barrier visible in normal memory model
133
+ (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
134
+
135
+ // Register reconfigurations
136
+ constexpr uint32_t kNumTMARegisters = 40;
137
+ constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232;
138
+
139
+ // Block scheduler
140
+ uint32_t m_block_idx, n_block_idx;
141
+ auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
142
+
143
+ // Pipeline and TMA phases
144
+ uint32_t stage_idx = 0, phase = 0;
145
+ auto advance_pipeline = [&](uint32_t& k_block_idx) {
146
+ ++ k_block_idx;
147
+
148
+ // Flip phases only if reach the next first stage
149
+ stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
150
+ phase ^= stage_idx == 0;
151
+ };
152
+
153
+ if (warp_idx >= kNumMathThreads / 32) {
154
+ // TMA warp-group for loading data
155
+ cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
156
+
157
+ // NOTES: only one thread (or warp) will be used
158
+ // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32`
159
+ if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) {
160
+ // Persistently schedule over blocks
161
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
162
+ // Assign TMA multicast number into A and B
163
+ // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
164
+ const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
165
+ const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
166
+ const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
167
+ DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
168
+
169
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
170
+ // Wait consumer release
171
+ empty_barriers[stage_idx]->wait(phase ^ 1);
172
+
173
+ // Issue TMA A
174
+ constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
175
+ const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
176
+
177
+ constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
178
+ auto& full_barrier = *full_barriers[stage_idx];
179
+ const uint32_t k_idx = k_block_idx * BLOCK_K;
180
+ tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
181
+ smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
182
+ num_tma_multicast_a, batch_idx);
183
+ tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
184
+ smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
185
+ num_tma_multicast_a);
186
+
187
+ // Issue TMA B
188
+ tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
189
+ smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
190
+ num_tma_multicast_b, batch_idx);
191
+ full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
192
+ }
193
+ }
194
+
195
+ // To safely deconstruct distributed shared barriers, we need another round of empty waits
196
+ if constexpr (kNumTMAMulticast > 1) {
197
+ for (uint32_t i = 0; i < kNumStages; advance_pipeline(i))
198
+ empty_barriers[stage_idx]->wait(phase ^ 1);
199
+ }
200
+ }
201
+ } else {
202
+ // Math warp-groups for WGMMA
203
+ cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
204
+
205
+ // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
206
+ const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
207
+ const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
208
+
209
+ auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
210
+ auto b_desc = make_smem_desc(smem_b[0], 1);
211
+ const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
212
+ const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
213
+
214
+ // Persistently schedule over blocks
215
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
216
+ // Decide the number of scales B to load
217
+ DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0);
218
+ uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
219
+ if constexpr (not kMustUseUniformedScaleB) {
220
+ num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
221
+ num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8;
222
+ }
223
+ uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2);
224
+
225
+ // Load B scales with math warp-groups
226
+ // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
227
+ if (threadIdx.x >= 32) {
228
+ auto previous_group_offset = scheduler.template get_global_idx<true, IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
229
+ const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
230
+ const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
231
+ auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
232
+
233
+ #pragma unroll
234
+ for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
235
+ st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb));
236
+ }
237
+ cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
238
+
239
+ // Accumulation for WGMMA or CUDA promotion
240
+ constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2;
241
+ DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
242
+ float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
243
+
244
+ // Pick threads whose WGMMA results are to be stored in shared memory
245
+ DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`");
246
+ constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M);
247
+ const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32;
248
+
249
+ // Empty barrier arrival
250
+ auto empty_barrier_arrive = [&]() {
251
+ if constexpr (kNumTMAMulticast == 1) {
252
+ lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void();
253
+ } else {
254
+ auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
255
+ lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void();
256
+ }
257
+ };
258
+
259
+ // Skip useless computations
260
+ if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
261
+ // The compiler must know the dynamic variable `num_former_iters`'s real value
262
+ constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
263
+ constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
264
+ constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
265
+
266
+ // Dispatch `num_former_iters` and launch MMAs
267
+ dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
268
+ #pragma unroll 8
269
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
270
+ const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
271
+ const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
272
+
273
+ // Read B scales
274
+ float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1;
275
+ // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
276
+ if constexpr (not kMustUseUniformedScaleB)
277
+ scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales);
278
+
279
+ // Wait TMA arrivals
280
+ full_barriers[stage_idx]->wait(phase);
281
+
282
+ // TODO: remove some useless computation for unaligned Ms
283
+ #pragma unroll
284
+ for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
285
+ auto m_offset = local_idx * WAVE_BLOCK_M;
286
+
287
+ // Read A scales
288
+ // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
289
+ auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
290
+ auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
291
+
292
+ // Commit WGMMA instructions
293
+ #pragma unroll
294
+ for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
295
+ warpgroup_fence_operand(accum[i]);
296
+ warpgroup_arrive();
297
+ #pragma unroll
298
+ for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
299
+ a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
300
+ b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
301
+ WGMMA::wgmma(a_desc, b_desc, accum, k);
302
+ }
303
+ warpgroup_commit_batch();
304
+ #pragma unroll
305
+ for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
306
+ warpgroup_fence_operand(accum[i]);
307
+ warpgroup_wait<0>();
308
+
309
+ // Notify barrier arrival at the last warpgroup wave
310
+ if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
311
+ empty_barrier_arrive();
312
+
313
+ // Skip promotion for the unfilled parts
314
+ if (not do_wgmma_store)
315
+ continue;
316
+
317
+ // Promote with scales
318
+ // NOTES: making it as predicates is very important for performance, comparing to two loops
319
+ float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
320
+ float scale_0_1, scale_1_1;
321
+ if constexpr (not kMustUseUniformedScaleB)
322
+ scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
323
+
324
+ auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
325
+ #pragma unroll
326
+ for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
327
+ // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
328
+ const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters;
329
+ shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
330
+ shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
331
+ shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
332
+ shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
333
+ }
334
+ }
335
+ }
336
+ });
337
+ } else {
338
+ #pragma unroll
339
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
340
+ full_barriers[stage_idx]->wait(phase);
341
+ empty_barrier_arrive();
342
+ }
343
+ }
344
+
345
+ // TMA checks
346
+ constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
347
+ constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
348
+ constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
349
+ DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
350
+ DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
351
+ "Unaligned TMA store or too many TMA store instructions");
352
+ DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
353
+
354
+ // Skip WGMMA store for the unfilled parts
355
+ if (not do_wgmma_store)
356
+ continue;
357
+
358
+ // Wait last TMA store to be finished
359
+ if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
360
+ cute::tma_store_wait<0>();
361
+ cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1);
362
+
363
+ // Write back to shared memory using STSM and issue TMA stores
364
+ DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
365
+ #pragma unroll
366
+ for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
367
+ auto m_offset = local_idx * WAVE_BLOCK_M;
368
+ auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
369
+ #pragma unroll
370
+ for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
371
+ // Swizzle or padding into the correct address
372
+ uint8_t* smem_ptr = nullptr;
373
+ if constexpr (kSwizzleDMode > 0) {
374
+ // Calculate the swizzling atom offset and in-atom offset
375
+ constexpr uint32_t kNumBankGroupBytes = 16;
376
+ auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
377
+
378
+ // Calculate the index of the bank group to be written in the atom
379
+ auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
380
+
381
+ // Reshape the atom in another view and swizzle
382
+ // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
383
+ // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
384
+ constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
385
+ auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
386
+ auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
387
+ col ^= row % (kSwizzleDMode / 16);
388
+
389
+ // Add back into the base pointer
390
+ // NOTES: think twice before modifying this, as changes may affect the number of instructions
391
+ smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
392
+ warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
393
+ m_offset * kSwizzleDMode + // Wave offset
394
+ atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
395
+ row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
396
+ } else {
397
+ // No swizzling, just padding
398
+ smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
399
+ }
400
+
401
+ // NOTES: only 16 lanes' addresses are used
402
+ SM90_U32x2_STSM_N<nv_bfloat162>::copy(
403
+ __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
404
+ __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
405
+ smem_ptr
406
+ );
407
+ }
408
+ }
409
+ cute::tma_store_fence();
410
+ cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1);
411
+
412
+ // Use TMA store to write back to global memory
413
+ // TODO: compatible with FP32 output
414
+ constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
415
+ DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
416
+ if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
417
+ auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
418
+ auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
419
+ auto n_idx = epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset);
420
+ auto m_idx = scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx);
421
+ if constexpr (kGemmType == GemmType::Batched) {
422
+ cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr,
423
+ n_idx, m_idx, scheduler.current_group_idx);
424
+ } else {
425
+ cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx);
426
+ }
427
+ cute::tma_store_arrive();
428
+ }
429
+ __syncwarp();
430
+ }
431
+ }
432
+ #else
433
+ if (blockIdx.x == 0 and threadIdx.x == 0)
434
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
435
+ #endif
436
+ }
437
+
438
+ }; // namespace deep_gemm
439
+
440
+ #pragma clang diagnostic pop
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/arch/barrier.h>
4
+ #include <cutlass/arch/reg_reconfig.h>
5
+
6
+ #include <cute/arch/cluster_sm90.hpp>
7
+ #include <cute/arch/copy_sm90_desc.hpp>
8
+ #include <cute/arch/mma_sm90_desc.hpp>
9
+
10
+ #include <deep_gemm/common/utils.cuh>
11
+ #include <deep_gemm/common/sm90_utils.cuh>
12
+
13
+ namespace deep_gemm {
14
+
15
+ using namespace deep_gemm::sm90;
16
+
17
+ // ReSharper disable once CppNotAllPathsReturnValue
18
+ template <uint32_t kHeadDim>
19
+ static constexpr int to_swizzle_cute_type() {
20
+ DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
21
+ if constexpr (kHeadDim == 32)
22
+ return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
23
+ if constexpr (kHeadDim == 64)
24
+ return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
25
+ if constexpr (kHeadDim == 128)
26
+ return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
27
+ }
28
+
29
// FP8 MQA logits kernel for SM90 (Hopper).
// For each query block it computes `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim]`
// via WGMMA, applies a ReLU, weights each head, reduces over heads and applies the
// per-KV scale, producing `[BLOCK_Q, BLOCK_KV]` logits per KV block.
// One TMA producer warp feeds multi-stage Q and KV shared-memory pipelines that the
// math warp-groups consume; stages are synchronized with cluster transaction barriers.
// With `kIsCompressedLogits`, each row's logits are written packed starting at its
// own `seq_k_start`; otherwise they are written at absolute KV positions.
template <uint32_t kNumHeads, uint32_t kHeadDim,
          bool kIsCompressedLogits,
          uint32_t BLOCK_Q, uint32_t BLOCK_KV,
          uint32_t kNumQStages, uint32_t kNumKVStages,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
                         const uint32_t max_seqlen_k, const uint64_t stride_logits,
                         uint32_t* cu_seq_len_k_start,
                         uint32_t* cu_seq_len_k_end,
                         float* logits,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_q,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
    // TODO: consider TMA multicast
    // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
    // Q should be load only at once for a block
    const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);

    // Types
    using WGMMA = typename FP8MMASelector<BLOCK_Q * kNumHeads>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    // Prefetch TMA descriptors (one thread of the first TMA warp)
    DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
    if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_q);
        cute::prefetch_tma_descriptor(&tensor_map_kv);
        cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
        cute::prefetch_tma_descriptor(&tensor_map_weights);
    }
    __syncwarp();

    // Shared memory configs
    // NOTES: weight may be unaligned
    static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
    static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
    static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);

    // Align to swizzling alignment bytes
    extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
    DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");

    // Data on shared memory
    // Layout: [Q stages][KV stages][weight stages][KV-scale stages][barriers]
    auto smem_q = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
            SMEM_Q_SIZE_PER_STAGE * i);
    });
    auto smem_kv = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
            SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
    });
    auto smem_weights = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer +
            SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
    });
    auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer +
            SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
            SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
    });

    // TMA barriers (placed right after the KV scales region)
    auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
    auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
    auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
    auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
    auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });

    // Initialize barriers
    // `full_*` are arrived by the single TMA-issuing lane; `empty_*` by every math thread
    const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
    if (is_tma_load_warp and cute::elect_one_sync()) {
        #pragma unroll
        for (uint32_t i = 0; i < kNumQStages; ++ i) {
            full_q_barriers[i]->init(1);
            empty_q_barriers[i]->init(kNumMathThreads);
        }
        #pragma unroll
        for (uint32_t i = 0; i < kNumKVStages; ++ i) {
            full_kv_barriers[i]->init(1);
            empty_kv_barriers[i]->init(kNumMathThreads);
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    }
    __syncthreads();

    // Register reconfigurations
    constexpr uint32_t kNumTMARegisters = 32;
    constexpr uint32_t kNumMathRegisters = 112;

    // Block scheduler: persistent blocks stride over Q blocks by `gridDim.x`
    uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
    const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
        return {block_q_idx + gridDim.x, q_iter_idx + 1};
    };
    uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
    const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
        // Union of the per-row KV ranges of this Q block, clamped to `seq_len_kv`
        uint32_t start = cute::numeric_limits<uint32_t>::max();
        uint32_t end = cute::numeric_limits<uint32_t>::min();

        #pragma unroll
        for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
            const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
            seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
            seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
            start = min(start, min(seq_k_start[i], seq_len_kv));
            end = max(end, min(seq_k_end[i], seq_len_kv));
        }
        // Round the start down to a multiple of 4
        start = start / 4 * 4;
        return {(q_iter_idx + q_iter_offset) % kNumQStages,       // Q pipeline stage
                ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
                start, ceil_div(end - start, BLOCK_KV)};          // Task info
    };

    // KV pipeline: stage/phase derived from the global count of issued KV blocks
    uint32_t num_total_kv_blocks = 0;
    const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
        return {
            (num_total_kv_blocks + kv_block_idx) % kNumKVStages,      // KV pipeline stage
            ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
        };
    };

    if (threadIdx.x >= kNumMathThreads) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();

        // Only the first warp remains
        if (not is_tma_load_warp)
            return;

        // Prefetch
        const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
            tma_copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
            tma_copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
            full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
        };
        if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
            issue_tma_q(0, block_q_idx);

        // Only the first lane persistently schedules over blocks
        if (cute::elect_one_sync()) {
            while (block_q_idx < num_q_blocks) {
                // NOTES: `load_schedule(1)` targets the stage of the *next* Q block,
                // since the current block's Q was already issued on the previous iteration
                CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks);

                // Wait Q consumer release
                empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);

                // Issue TMA Q
                if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks)
                    issue_tma_q(q_stage_idx, next_block_q_idx);

                // Issue TMA KV
                #pragma unroll
                for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
                    // Wait consumer release
                    CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
                    empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);

                    // Issue TMA KV
                    tma_copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
                                                           smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
                    tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
                                             smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
                    full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
                }
                num_total_kv_blocks += num_kv_blocks;

                // Jump to the next block
                CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
            }
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
        const auto& thread_idx = threadIdx.x % kNumMathThreads;
        const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
        const auto& warpgroup_idx = warp_idx / 4;
        const auto& lane_idx = get_lane_idx();
        float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];

        // Each warp owns 16 KV rows; each thread handles two rows (+0 and +8)
        const auto& warp_offset = warp_idx * 16;
        const auto& v_0_offset = lane_idx / 4 + 0;
        const auto& v_1_offset = lane_idx / 4 + 8;

        while (block_q_idx < num_q_blocks) {
            CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);

            // Wait TMA Q arrival
            full_q_barriers[q_stage_idx]->wait(q_phase);

            // Read weights (distributed over lanes in the WGMMA accumulator layout)
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
                #pragma unroll
                for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
                    weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
            }

            // Compute over KV blocks
            #pragma unroll
            for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
                // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
                // Wait TMA KV arrival
                CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
                full_kv_barriers[kv_stage_idx]->wait(kv_phase);

                // Read per-KV scales
                float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
                float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);

                // Issue WGMMA (KV is operand A, Q is operand B)
                DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
                DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
                #pragma unroll
                for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                    warpgroup_fence_operand(accum[i]);
                warpgroup_arrive();
                #pragma unroll
                for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
                    auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
                                                 to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
                    auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K,
                                                 to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
                    WGMMA::wgmma(desc_a, desc_b, accum, k);
                }
                warpgroup_commit_batch();
                #pragma unroll
                for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                    warpgroup_fence_operand(accum[i]);
                warpgroup_wait<0>();

                // Release KV empty
                empty_kv_barriers[kv_stage_idx]->arrive();

                // Reduce over the head dim and store
                const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
                static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
                DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
                DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation");
                DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
                #pragma unroll
                for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
                    auto shifted_accum = accum + i * kNumAccumPerReduce;
                    // ReLU + per-head weighting before the head reduction
                    const auto& transform = [&](const uint32_t& j) {
                        return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
                    };

                    // Intra-thread reduction
                    float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
                    #pragma unroll
                    for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
                        #pragma unroll
                        for (uint32_t k = 0; k < 4; k ++)
                            sum[k] += transform(j * 4 + k);
                    }
                    float v_0 = (sum[0] + sum[1]) * scale_kv_0;
                    float v_1 = (sum[2] + sum[3]) * scale_kv_1;

                    // Inter-thread reduction (butterfly over 4 lanes sharing a KV row)
                    #pragma unroll
                    for (uint32_t j = 0; j < 2; ++ j) {
                        const auto& offset = static_cast<int>(1u << j);
                        v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
                        v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
                    }

                    // Store into the global memory
                    // NOTES: we have redundant writes here, consider more carefully
                    const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
                    if constexpr (kIsCompressedLogits) {
                        // Packed layout: each row starts at its own `seq_k_start`
                        if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
                            logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0;
                        if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
                            logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1;
                    } else {
                        logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0;
                        logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1;
                    }
                }
            }
            num_total_kv_blocks += num_kv_blocks;

            // Release Q empty
            empty_q_barriers[q_stage_idx]->arrive();

            // Jump to the next block
            CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
        }
    }
}
328
+
329
+ } // namespace deep_gemm
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/arch/barrier.h>
4
+ #include <cutlass/arch/reg_reconfig.h>
5
+
6
+ #include <cute/arch/cluster_sm90.hpp>
7
+ #include <cute/arch/copy_sm90_desc.hpp>
8
+
9
+ #include <deep_gemm/common/utils.cuh>
10
+ #include <deep_gemm/common/sm90_utils.cuh>
11
+ #include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
12
+
13
+ namespace deep_gemm {
14
+
15
// Builds the per-SM schedule consumed by `PagedMQALogitsScheduler`:
// computes how many `SPLIT_KV`-sized segments each request needs, prefix-sums
// them across the batch, and splits the total evenly over `kNumSMs` SMs.
// Writes `kNumSMs + 1` `(q_idx, kv_split_idx)` pairs into `schedule_metadata`;
// the extra entry at index `kNumSMs` is the end sentinel for the last SM.
// Launched with a single warp (32 threads).
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
__global__ __launch_bounds__(32, 1)
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
                                    const uint32_t* context_lens, uint32_t* schedule_metadata) {
    DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
    const uint32_t lane_idx = get_lane_idx();

    // Number of KV segments required per request (0 for out-of-batch padding slots)
    uint32_t num_segs[kAlignedBatchSize / 32];
    #pragma unroll
    for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
        const uint32_t q_idx = k * 32 + lane_idx;
        // For 2D context lengths, read the last of the `next_n` entries per request
        const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
        const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0);
        num_segs[k] = ceil_div(context_len, SPLIT_KV);
    }

    // Inclusive prefix sum over all requests, 32 at a time via `__shfl_up_sync`
    __shared__ uint32_t prefix_sum[kAlignedBatchSize];
    uint32_t sum = 0;
    #pragma unroll
    for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
        uint32_t x = num_segs[k];
        #pragma unroll
        for (uint32_t offset = 1; offset < 32; offset <<= 1) {
            const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset);
            x += (lane_idx >= offset ? y : 0);
        }
        // Carry the running total of previous 32-request chunks
        x += sum;
        prefix_sum[k * 32 + lane_idx] = x;
        sum = __shfl_sync(0xffffffff, x, 31);
    }

    // Distribute `sum` segments evenly: the first `r` SMs receive one extra segment.
    // NOTES: the loop bound is `<= kNumSMs` on purpose — entry `kNumSMs` is the sentinel
    const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs;
    for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
        uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
        // Find the request whose range contains this SM's first segment
        uint32_t q_idx = 0;
        while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts)
            ++ q_idx;
        // Split index within that request's segments
        const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]);
        __syncwarp();

        schedule_metadata[sm_idx * 2] = q_idx;
        schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
    }
}
59
+
60
+ template <uint32_t kNextN, bool kIsContextLens2D,
61
+ uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit>
62
+ struct PagedMQALogitsScheduler {
63
+ uint32_t batch_size;
64
+ const uint32_t* context_lens;
65
+
66
+ uint32_t current_q_idx, current_kv_idx;
67
+ uint32_t end_q_idx, end_kv_idx;
68
+ uint32_t current_num_kv;
69
+
70
+ __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) {
71
+ const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
72
+ return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0;
73
+ }
74
+
75
+ __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx,
76
+ const uint32_t* context_lens, const uint32_t* schedule_meta) {
77
+ this->batch_size = batch_size;
78
+ this->context_lens = context_lens;
79
+
80
+ const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
81
+ const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
82
+ current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
83
+ end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
84
+
85
+ current_num_kv = get_num_kv(current_q_idx);
86
+ }
87
+
88
+ __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) {
89
+ q_idx = current_q_idx;
90
+ kv_idx = current_kv_idx;
91
+ num_kv = current_num_kv;
92
+
93
+ if (q_idx == end_q_idx and kv_idx == end_kv_idx)
94
+ return false;
95
+
96
+ current_kv_idx += kNumBlocksPerSplit;
97
+ if (current_kv_idx >= current_num_kv) {
98
+ ++ current_q_idx;
99
+ current_kv_idx = 0;
100
+ current_num_kv = get_num_kv(current_q_idx);
101
+ }
102
+
103
+ return true;
104
+ }
105
+
106
+ __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const {
107
+ return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx;
108
+ }
109
+ };
110
+
111
+ using namespace deep_gemm::sm90;
112
+
113
// FP8 paged MQA logits kernel for SM90 (Hopper).
// Each SM walks its `PagedMQALogitsScheduler` task list; per task it gathers one
// `BLOCK_KV` page via the block table, computes
// `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim]` with WGMMA, applies
// ReLU + per-head weights + per-KV scales, reduces over heads, and writes
// `[kNextN, BLOCK_KV]` logits. TMA warps and math warp-groups are paired into
// `kNumMathWarpGroups` KV groups, each with its own KV pipeline and barriers;
// the Q pipeline is shared by all math threads.
template <uint32_t kNextN, uint32_t kNumHeads,
          uint32_t kHeadDim, uint32_t BLOCK_KV,
          bool kIsContextLens2D,
          uint32_t kNumQStages, uint32_t kNumKVStages,
          uint32_t SPLIT_KV,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
                               const uint64_t logits_stride, const uint64_t block_table_stride,
                               const uint32_t* context_lens, float* logits,
                               const uint32_t* block_table, const uint32_t* schedule_meta,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_q,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
    // Types
    using WGMMA = typename FP8MMASelector<kNextN * kNumHeads>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
    const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const auto& warpgroup_idx = warp_idx / 4;
    const auto& lane_idx = get_lane_idx();

    // Prefetch TMA descriptors (one thread of the first TMA warp)
    static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
    DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
    DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
    if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_q);
        cute::prefetch_tma_descriptor(&tensor_map_kv);
        cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
        cute::prefetch_tma_descriptor(&tensor_map_weights);
    }
    __syncwarp();

    // Shared memory configs
    static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
    static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
    static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
    // The trailing aligned region holds the Q full/empty barriers (8 bytes each)
    static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
                                                 constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);

    static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
    static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
    static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
    // Likewise, each KV group's pipe ends with its KV full/empty barriers
    static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
                                                  constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);

    // Align to swizzling alignment bytes
    extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
    DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");

    // Q data and barriers on shared memory
    auto smem_q = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
    });
    auto smem_weights = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
    });
    auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
    auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
    auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });

    // Separate math warpgroups and tma load warps into KV groups
    // Each math warpgroup corresponds to a tma load warp
    const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);

    // Per group KV data and barriers on shared memory
    const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
    auto smem_kv = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
    });
    auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
    });
    auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
    auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
    auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });

    // Initialize barriers: one TMA lane per group initializes its own KV barriers;
    // group 0's lane additionally initializes the shared Q barriers
    if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
        if (kv_group_idx == 0) {
            #pragma unroll
            for (uint32_t i = 0; i < kNumQStages; ++ i) {
                full_q_barriers[i]->init(1);
                empty_q_barriers[i]->init(kNumMathThreads);
            }
        }
        if (kv_group_idx < kNumMathWarpGroups) {
            #pragma unroll
            for (uint32_t i = 0; i < kNumKVStages; ++ i) {
                full_kv_barriers[i]->init(1);
                // Only the paired 128-thread math warpgroup consumes this group's KV
                empty_kv_barriers[i]->init(128);
            }
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    }
    __syncthreads();

    // Register reconfigurations
    constexpr uint32_t kNumTMARegisters = 64;
    constexpr uint32_t kNumMathRegisters = 104;

    // Scheduler
    auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
    DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");

    // Q and KV pipeline
    const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
        return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
    };
    const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
        return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
    };
    uint32_t q_iter_idx = 0, kv_iter_idx = 0;

    if (warp_idx >= kNumMathThreads / 32) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
        if (kv_group_idx >= kNumMathWarpGroups)
            return;

        const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
            // Only group 0 loads Q; the Q pipeline is shared across groups
            if (kv_group_idx == 0 and cute::elect_one_sync()) {
                tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
                tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
                full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
            }
        };

        // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
        uint32_t q_idx = batch_size, kv_idx, num_kv;
        uint32_t next_q_idx, next_kv_idx, next_num_kv;
        bool fetched_next_task;

        // Prefetch the first Q
        if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)))
            issue_tma_q(0, next_q_idx), q_iter_idx = 1;

        // Lane-distributed cache of block-table entries; `32` means "empty, refill"
        int kv_block_idx_ptr = 32;
        uint32_t kv_block_idx_storage;

        while (fetched_next_task) {
            // Prefetch next Q when current Q changes
            bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
            q_idx = next_q_idx;
            kv_idx = next_kv_idx;
            num_kv = next_num_kv;

            // Wait Q consumer release and issue TMA Q
            if (prefetch_q) {
                CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
                empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
                issue_tma_q(q_stage_idx, q_idx + 1);
            }

            // Read KV block index
            // Each lane caches one entry, strided by `kNumMathWarpGroups` from this group's slot
            // TODO: deal with `-1`?
            if (kv_idx == 0 or kv_block_idx_ptr == 32) {
                kv_block_idx_ptr = 0;
                kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
                    __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0);
            }
            const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);

            // Wait KV consumer release
            CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
            empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);

            // Issue TMA KV
            if (cute::elect_one_sync()) {
                tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
                                                                     smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
                tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
                                         smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
                full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
            }

            // Fetch next task
            fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv);
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
        // Each warp owns 16 KV rows; each thread handles two rows (+0 and +8)
        const auto& sub_warp_offset = (warp_idx % 4) * 16;
        const auto& v_0_offset = lane_idx / 4 + 0;
        const auto& v_1_offset = lane_idx / 4 + 8;

        // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
        uint32_t q_idx = batch_size, kv_idx;
        uint32_t next_q_idx, next_kv_idx, next_num_kv;
        uint32_t q_stage_idx, q_phase;

        while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
            // Current Q changes
            if (q_idx != next_q_idx) {
                // Release Last Q empty
                if (q_iter_idx > 0)
                    empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();

                // Wait TMA Q arrival
                CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
                full_q_barriers[q_stage_idx]->wait(q_phase);

                // Read weights (distributed over lanes in the WGMMA accumulator layout)
                #pragma unroll
                for (uint32_t i = 0; i < kNextN; ++ i) {
                    #pragma unroll
                    for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
                        weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
                }
            }

            // Get current Q and KV index
            q_idx = next_q_idx;
            kv_idx = next_kv_idx;

            // Calculate KV offset in advance
            auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);

            // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
            // Wait TMA KV arrival
            CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
            full_kv_barriers[kv_stage_idx]->wait(kv_phase);

            // Issue WGMMA (KV is operand A, Q is operand B)
            DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size");
            DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);
            warpgroup_arrive();
            #pragma unroll
            for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
                auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
                auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
                WGMMA::wgmma(desc_a, desc_b, accum, k);
            }
            warpgroup_commit_batch();
            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);

            // Read per-KV scales (overlapped with the in-flight WGMMA)
            float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
            float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);

            // Wait WGMMA
            warpgroup_wait<0>();

            // Release KV empty
            empty_kv_barriers[kv_stage_idx]->arrive();

            // Reduce over the head dim and store
            static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
            DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
            DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation");
            DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
            #pragma unroll
            for (uint32_t i = 0; i < kNextN; ++ i) {
                auto shifted_accum = accum + i * kNumAccumPerReduce;
                // ReLU + per-head weighting before the head reduction
                const auto& transform = [&](const uint32_t& j) {
                    return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
                };

                // Intra-thread reduction
                float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
                #pragma unroll
                for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
                    #pragma unroll
                    for (uint32_t k = 0; k < 4; k ++)
                        sum[k] += transform(j * 4 + k);
                }
                float v_0 = (sum[0] + sum[1]) * scale_kv_0;
                float v_1 = (sum[2] + sum[3]) * scale_kv_1;

                // Inter-thread reduction (butterfly over 4 lanes sharing a KV row)
                #pragma unroll
                for (uint32_t j = 0; j < 2; ++ j) {
                    const auto& offset = static_cast<int>(1u << j);
                    v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
                    v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
                }

                // Store into the global memory
                // NOTES: we have redundant writes here, consider more carefully
                logits[kv_offset + i * logits_stride + v_0_offset] = v_0;
                logits[kv_offset + i * logits_stride + v_1_offset] = v_1;
            }
        }
    }
}
412
+
413
+ } // namespace deep_gemm
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #pragma clang diagnostic push
3
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
4
+
5
+ #include <cutlass/arch/barrier.h>
6
+ #include <cutlass/arch/reg_reconfig.h>
7
+
8
+ #include <deep_gemm/common/reduction.cuh>
9
+ #include <deep_gemm/common/utils.cuh>
10
+ #include <deep_gemm/common/sm90_utils.cuh>
11
+
12
+ namespace deep_gemm {
13
+
14
+ using namespace deep_gemm::sm90;
15
+
16
+ template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
17
+ __device__ __forceinline__
18
+ uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
19
+ constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
20
+
21
+ const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange;
22
+
23
+ constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
24
+ constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups;
25
+ auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups);
26
+ auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups);
27
+ col ^= row % kGroupsInSwizzleRange;
28
+
29
+ return (row * kNumBankGroups + col) % kGroupsInSwizzleRange;
30
+ }
31
+
32
// SM90 (Hopper) TF32 GEMM kernel with a fused "prenorm" reduction: while
// loading the BF16 A operand, each thread also accumulates the squared sum of
// its A elements per row, which is written to `sqr_sum` (one value per row,
// per K-split). The GEMM itself is split along K into `kNumSplits` partial
// results stored as separate slices of D.
//
// Warp layout: one 128-thread math warpgroup (4 warps) consumes data; one
// extra warp issues TMA loads (producer). Producer/consumer hand-off uses
// cluster transaction barriers in a `kNumStages`-deep pipeline.
//
// NOTE(review): A is stored in BF16 and up-converted to FP32 in registers
// before the TF32 WGMMA; B is already FP32 in shared memory — confirm against
// the host-side tensor-map construction.
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
          uint32_t kNumSplits,
          uint32_t kSwizzleCDMode,
          uint32_t kNumStages,
          uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_b,
                               const __grid_constant__ cute::TmaDescriptor tensor_map_d,
                               float* sqr_sum) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    // kSwizzleAMode and kSwizzleBMode must be 128 for now
    constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
    constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
    DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K");
    DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
    DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode");

    DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
    DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads");

    // Utils
    const auto warp_idx = cutlass::canonical_warp_idx_sync();
    const auto lane_idx = get_lane_idx();

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];

    // Share memory sizes
    constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
    constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
    constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
    DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");

    // Prefetch TMA descriptors into cache (single elected thread suffices)
    if (warp_idx == 0 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_a);
        cute::prefetch_tma_descriptor(&tensor_map_b);
        cute::prefetch_tma_descriptor(&tensor_map_d);
    }

    // Data on shared memory (layout as ordered below)
    // Fill D/A/B pointers
    auto smem_cd = reinterpret_cast<float*>(smem_buffer);
    auto smem_a = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
    });
    auto smem_b = PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
    });

    // Fill barriers (placed after all data buffers)
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
    auto full_barriers  = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
    auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });

    // Initialize barriers: `full` is signaled by the single TMA thread,
    // `empty` is signaled by all 128 math threads
    if (warp_idx == 1 and cute::elect_one_sync()) {
        #pragma unroll
        for (uint32_t i = 0; i < kNumStages; ++ i) {
            full_barriers[i]->init(1);
            empty_barriers[i]->init(128);
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    }
    __syncthreads();

    // Split the K dimension across `kNumSplits` blocks; the first
    // `kRemainKBlocks` splits take one extra K block each
    constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
    constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
    constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
    const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
    const uint32_t m_block_idx = block_idx / kNumSplits;
    const uint32_t k_split_idx = block_idx % kNumSplits;
    const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
    // Each K-split writes its own `shape_m`-sized slice of `sqr_sum`
    const uint32_t m_offset = shape_m * k_split_idx;
    const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
    constexpr uint32_t kNumTMARegisters = 40;
    constexpr uint32_t kNumMathRegisters = 256;

    // TMA load warp (producer)
    if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
        for (uint32_t s = 0; s < num_total_stages; ++ s) {
            // Wait consumer release
            const auto& stage_idx = s % kNumStages;
            empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);

            // Compute offsets
            uint32_t m_idx = m_block_idx * BLOCK_M;
            uint32_t k_idx = k_offset + s * BLOCK_K;

            // Issue TMAs
            tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
            tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);

            // Arrive at full barriers
            constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
            full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
        }

        // Drain: wait until the consumer has released every in-flight stage
        for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
            const auto& stage_idx = s % kNumStages;
            empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
        }
    } else if (warp_idx < kNumMathThreads / 32) {
        // Math warpgroup (consumer)
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
        DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
        constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
        constexpr uint32_t WGMMA_M = 64;
        constexpr uint32_t WGMMA_N = BLOCK_N;
        constexpr uint32_t WGMMA_K = 8;

        using WGMMA = typename TF32MMASelector<WGMMA_N, true>::type;
        float accum[WGMMA::kNumAccum] = {0};

        constexpr uint32_t kNumBankGroupBytes = 16;
        constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
        constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
        // Per-thread squared-sum accumulators for two rows (row, row + 8)
        float sqr_sum_acc_0 = 0;
        float sqr_sum_acc_1 = 0;

        #pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2
        for (uint32_t s = 0; s < num_total_stages; ++ s) {
            // Wait TMA arrival
            const auto& stage_idx = s % kNumStages;
            full_barriers[stage_idx]->wait((s / kNumStages) & 1);

            constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128;
            constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K;

            float a[kNumRegPerWgmma * kNumWgmmaPerBlockK];
            // Assume swizzle A mode is 128
            DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");

            // Load BF16 A fragment from shared memory into registers, and transpose to FP32
            uint32_t row = warp_idx * 16 + lane_idx / 4;
            #pragma unroll
            for (uint32_t i = 0; i < kNumLoads; ++ i) {
                // Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a
                // Undo the 128B swizzle applied by the TMA load
                uint32_t bank_group_idx = (row ^ i) % 8;
                nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
                nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;

                uint32_t elem_offset = lane_idx % 4;
                nv_bfloat16 a_bf16[kNumRegPerWgmma];
                a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset];
                a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4];
                a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset];
                a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4];

                // Up-convert to FP32 and fold the squared values into the
                // prenorm accumulators (x lanes -> row, y lanes -> row + 8)
                auto a_bf16x2_ptr = reinterpret_cast<nv_bfloat162*>(a_bf16);
                auto a_float2_ptr = reinterpret_cast<float2*>(a);
                float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]);
                float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]);
                a_float2_ptr[i * 2 + 0] = a_float2_0;
                a_float2_ptr[i * 2 + 1] = a_float2_1;
                sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x;
                sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
            }

            // Wait for the previous stage's WGMMAs before releasing its buffers
            warpgroup_wait<0>();
            if (s > 0)
                empty_barriers[(s - 1) % kNumStages]->arrive();

            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);
            warpgroup_arrive();

            constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
            constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
            DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K");

            // Issue one WGMMA per 8-wide K slice; A comes from registers,
            // B via a shared-memory descriptor
            #pragma unroll
            for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
                #pragma unroll
                for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
                    auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
                    WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
                }
            }
            warpgroup_commit_batch();
            #pragma unroll
            for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
                warpgroup_fence_operand(accum[i]);
        }

        // Reduce the squared sums across the 4 lanes that share a row
        const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0);
        const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1);

        const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
        if (lane_idx % 4 == 0) {
            if (m_idx < shape_m)
                sqr_sum[m_offset + m_idx] = reduced_sum_0;
            if (m_idx + 8 < shape_m)
                sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
        }
        // Wait for the last WGMMA batch, then release the final stage
        warpgroup_wait<0>();
        empty_barriers[(num_total_stages - 1) % kNumStages]->arrive();

        // Write accum to shared memory
        // Every 2 threads (one pair) will write to the same bank group (16 bytes).
        // Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d
        uint32_t is_odd_pair = lane_idx / 2 % 2;

        // Four threads per group; write the data to the same row.
        uint32_t row_idx = lane_idx / 4;

        // Even/odd index pairs write to the same column, we need to reorder idx:
        // group even pair indices consecutively, and likewise for odd ones.
        uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx;

        auto shifted_smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) +
            (warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode +  // Row offset, each warp has 16 rows
            lane_idx % 2 * 8;                                           // One thread of a pair writes 8 bytes

        #pragma unroll
        for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) {
            // Get the swizzled bank group index (16 bytes per group)
            uint32_t bank_group_idx = get_swizzled_bank_group_idx<kSwizzleCDMode>(i + is_odd_pair, reordered_pair_idx);
            auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes;  // Col offset, 16 bytes per group

            // 0/1 write to the same row, 2/3 write to another row
            auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
            st_shared(smem_ptr, values[0], values[1]);
            st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
        }
        // Make the shared-memory writes visible to the TMA store unit
        cute::tma_store_fence();
        cutlass::arch::NamedBarrier::sync(128, 1);

        // Issue TMA stores
        if (warp_idx == 0 and cute::elect_one_sync()) {
            if constexpr (kNumSplits == 1) {
                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
            } else {
                // Each K-split writes its own slice along the 3rd dimension
                cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
            }
            cute::tma_store_arrive();
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
284
+
285
+ } // namespace deep_gemm
286
+
287
+ #pragma clang diagnostic pop
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/arch/barrier.h>
4
+ #include <cute/arch/cluster_sm90.hpp>
5
+
6
+ #include <deep_gemm/common/utils.cuh>
7
+
8
+ namespace deep_gemm {
9
+
10
// Masks out-of-window logits with `-inf`: for each query row `i`, every
// column outside `[ks, ke)` (the valid KV range read from the cumulative
// sequence-length arrays) is overwritten with negative infinity.
// Bulk regions are blasted with SM90 bulk S2G copies from a shared-memory
// buffer pre-filled with `-inf`; the few elements at the unaligned edges of
// the window are fixed up with plain stores.
template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps>
__global__ __launch_bounds__(kNumWarps * 32, 1)
void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits,
                       const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) {
    const uint32_t& num_sms = gridDim.x;
    const uint32_t& sm_idx = blockIdx.x;
    const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    constexpr float neg_inf = -cute::numeric_limits<float>::infinity();

    // Allocate filled `-inf` shared memory
    extern __shared__ __align__(1024) float smem_buffer[];
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
        smem_buffer[i] = neg_inf;
    cute::tma_store_fence();
    __syncthreads();

    // Assign sequence to each warp: split `total` rows starting at `start`
    // into `num` near-equal chunks, the first `rem` chunks one row larger
    const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx,
                                  const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
        const auto& per = total / num, rem = total % num;
        return {start + idx * per + min(idx, rem), per + (idx < rem)};
    };
    CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
    CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);

    // Bulk copies: one elected lane per warp issues the S2G transfers
    if (cute::elect_one_sync()) {
        for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
            // Valid KV window [ks, ke) for row `i`; kNextN rows share one
            // cu_seq_len entry, each successive row extends `ke` by one
            const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
            const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
            // Bulk copies require 16-byte (4-float) aligned extents
            const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;

            for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
                const auto& right = min(left + BLOCK_KV, static_cast<uint32_t>(stride_logits));
                if (right <= ks or ke <= left) {
                    // Tile fully outside the valid window: wipe it entirely
                    cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float));
                } else {
                    // Tile overlaps the window: wipe only the aligned margins
                    if (left < aligned_ks)
                        cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float));
                    if (aligned_ke < right)
                        cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float));
                }
            }
        }
    }

    // Scalar fix-up of the unaligned edge elements the bulk copies skipped:
    // [aligned_ks, ks) and [ke, aligned_ke)
    for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
        const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
        const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
        const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
        for (uint32_t j = aligned_ks; j < ks; ++ j)
            logits[i * stride_logits + j] = neg_inf;
        for (uint32_t j = ke; j < aligned_ke; ++ j)
            logits[i * stride_logits + j] = neg_inf;
    }
}
66
+
67
+ }
build/torch29-cxx11-cu129-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <deep_gemm/common/utils.cuh>
4
+
5
+ namespace deep_gemm {
6
+
7
// Transposes a K-major FP32 scaling-factor tensor of shape (mn, SF_K) into an
// MN-major layout of shape (SF_K, tma_aligned_mn), padding MN up to a 16-byte
// TMA-aligned count. Each block handles BLOCK_MN rows of one batch slice
// (blockIdx.y); the transpose goes through shared memory padded to
// PADDED_SF_K columns to avoid bank conflicts on the strided reads.
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
          uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
    // Widest vector type that evenly covers SF_K floats per row
    typedef typename Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
    constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
    constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;

    // Shapes and strides
    extern __shared__ float smem_buffer[];
    constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
    const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
    const auto tma_aligned_mn = align<uint32_t>(mn, kNumTMAAlignedElems);

    // Shift into the block (blockIdx.y selects the batch/group slice)
    sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
    out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
    const auto& local_sf = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));

    // Load: vectorized global reads, scatter into padded shared memory
    for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) {
        auto in_vec = __ldg(local_sf + i);
        const auto& in_values = reinterpret_cast<float*>(&in_vec);

        const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec;
        #pragma unroll
        for (uint32_t j = 0; j < kNumElemsPerVec; ++ j)
            smem_buffer[row * PADDED_SF_K + col + j] = in_values[j];
    }
    __syncthreads();

    // Store: read transposed from shared memory, write MN-major
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) {
        const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn;
        const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx;
        out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
    }
}
+
46
+ // NOTES: the two kernels below always pack the K dimension
47
+
48
// Transposes FP32 scaling factors to MN-major and packs 4 consecutive K
// values into one uint32 of UE8M0 exponents (one byte per SF, extracted from
// the FP32 exponent field). Output shape: (ceil(SF_K/4), tma_aligned_mn)
// per blockIdx.y slice.
// NOTE(review): the packing assumes each SF is a non-negative power of two
// (sign and mantissa bits zero), so `>> 23` yields exactly the 8 exponent
// bits — confirm against the producer of `sf`.
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
    extern __shared__ uint32_t smem_buffer[];

    // Shapes and strides
    constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u);
    constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));
    const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
    const auto tma_aligned_mn = align<uint64_t>(mn, kNumTMAAlignedElems);

    // Shift into the group
    sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
    out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * kNumPackedSFK;

    // Load FP32 SFs (reinterpreted as raw uint32 bit patterns) into shared
    // memory, using 16-byte vector loads for the aligned bulk
    DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
    const auto local_sf = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
    const auto num_values = in_block_mn * SF_K;
    const auto num_uint4 = num_values / 4;
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) {
        const auto& [x, y, z, w] = __ldg(reinterpret_cast<uint4*>(local_sf) + i);
        st_shared(reinterpret_cast<uint4*>(smem_buffer) + i, x, y, z, w);
    }

    // Fill unaligned values as well
    if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values)
        st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx));
    __syncthreads();

    // Pack into UE8M0 and store
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) {
        const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN;

        // Load shared memory (zero-pad past the end of SF_K)
        uint32_t values[4];
        #pragma unroll
        for (uint32_t j = 0; j < 4; ++ j) {
            const auto sf_k_idx = sf_k_pack_idx * 4 + j;
            values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0;
        }

        // Pack and store: move each FP32 exponent (bits 30:23) into
        // byte j of the packed word
        uint32_t packed = 0;
        packed |= (values[0] >> 23u);
        packed |= (values[1] >> 15u);
        packed |= (values[2] >> 7u);
        packed |= (values[3] << 1u);
        if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn)
            out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed;
    }
}
+ }
101
+
102
// Packs already-MN-major FP32 scaling factors into UE8M0 (4 K-values per
// uint32, one exponent byte each) without a transpose. Supports grouped
// inputs (kNumGroups > 1), where per-group K sizes come from `ks` and packing
// never mixes SFs from different groups: each group's K count is rounded up
// to a multiple of 4 with zero padding via `input_offset`.
// One warp produces one packed output row (BLOCK_PACKED_SF_K warps / block).
// NOTE(review): like the transpose variant, this assumes SFs are positive
// powers of two so the exponent-shift packing is lossless — confirm upstream.
template <uint32_t kNumGroups, uint32_t kNumThreads,
          uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
                                     const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) {
    // Always packing the K dimension
    // NOTES: should also assert `mn % 4 == 0` at launch
    DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
    DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes");
    DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes");

    // Shapes and strides
    const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
    const auto in_block_mn_uint4 = in_block_mn / 4;
    const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K);

    // Shift into the right block along MN
    sf += blockIdx.x * BLOCK_MN;
    out += blockIdx.x * BLOCK_MN;

    // Each warp is responsible for a packed row
    const auto warp_idx = threadIdx.x / 32;
    const auto lane_idx = get_lane_idx();
    const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
    if (warp_idx >= in_block_packed_sf_k)
        return;

    // Make an offset on the input: for grouped inputs, count how many padding
    // slots precede this warp's packed row so group boundaries stay aligned
    uint32_t input_offset = 0;
    if constexpr (kNumGroups > 1) {
        // Load each group's size (4 per lane, spread across the warp)
        DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups");
        uint32_t group_ks[4];
        #pragma unroll
        for (uint32_t i = 0; i < 4; ++ i) {
            const auto group_idx = lane_idx * 4 + i;
            group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0;
        }
        __syncwarp();

        // Make the offset: walk the groups, broadcasting each size via
        // shuffle, until this packed row falls inside a group
        sf_k = 0;
        auto sum_packed_sf_k = 0;
        #pragma unroll
        for (uint32_t i = 0; i < kNumGroups; ++ i) {
            const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4);
            sf_k += sf_k_in_group;
            sum_packed_sf_k += ceil_div(sf_k_in_group, 4u);
            if (packed_sf_k_idx < sum_packed_sf_k)
                break;
            if (const auto remainder = sf_k_in_group % 4; remainder > 0)
                input_offset += 4 - remainder;
        }
    }

    for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
        // Load 4 K rows (16 bytes of MN each); rows past the group's real
        // K extent read as zero
        uint4 values[4];
        #pragma unroll
        for (uint32_t j = 0; j < 4; ++ j) {
            values[j] = make_uint4(0, 0, 0, 0);
            if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k)
                values[j] = __ldg(reinterpret_cast<uint4*>(sf + sf_k_idx * mn) + mn_idx);
        }

        // Pack and store: byte j of each lane holds K-value j's exponent
        uint4 packed;
        packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u);
        packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u);
        packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u);
        packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u);
        reinterpret_cast<uint4*>(out + packed_sf_k_idx * mn)[mn_idx] = packed;
    }
}
+ }
175
+
176
+ } // namespace deep_gemm
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include <vector>
35
+ #include <iostream>
36
+
37
+ // Cutlass command line parser
38
+ #include "cutlass/util/command_line.h"
39
+
40
+ class Options {
41
+ public:
42
+
43
+ bool help;
44
+ bool good;
45
+ std::vector<int> extent; ///< extent of tile to fill
46
+ std::vector<int> stride; ///< stride vector for layout function
47
+ std::vector<int> output_shape; ///< output shape
48
+ int vectorize; ///< sequences of consecutive output elements are concatenated into a vector
49
+ /// if, and only if, they were consecutive in source memory
50
+
51
+ public:
52
+
53
+ /// Options
54
+ Options():
55
+ help(false),
56
+ good(true),
57
+ extent({32, 8}),
58
+ stride({32}),
59
+ output_shape({16, 8}),
60
+ vectorize(1) {
61
+
62
+ }
63
+
64
+ /// Constructs from command line parser
65
+ Options(cutlass::CommandLine const & cmd_line): help(false), good(true) {
66
+
67
+ if (cmd_line.check_cmd_line_flag("help") ||
68
+ cmd_line.check_cmd_line_flag("h")) {
69
+
70
+ help = true;
71
+ }
72
+
73
+ if (cmd_line.check_cmd_line_flag("extent")) {
74
+ cmd_line.get_cmd_line_arguments("extent", extent);
75
+ }
76
+ else {
77
+ extent = {32, 8};
78
+ }
79
+
80
+ if (cmd_line.check_cmd_line_flag("stride")) {
81
+ cmd_line.get_cmd_line_arguments("stride", stride);
82
+ }
83
+
84
+ int default_output_shape[] = {16, 8};
85
+
86
+ if (cmd_line.check_cmd_line_flag("output-shape")) {
87
+ cmd_line.get_cmd_line_arguments("output-shape", output_shape);
88
+ }
89
+
90
+ for (int i = int(output_shape.size()); i < 2; ++i) {
91
+ output_shape.push_back(default_output_shape[i]);
92
+ }
93
+
94
+ if (cmd_line.check_cmd_line_flag("vectorize")) {
95
+ cmd_line.get_cmd_line_argument("vectorize", vectorize);
96
+ }
97
+ else {
98
+ vectorize = 1;
99
+ }
100
+
101
+ if (output_shape.front() % vectorize) {
102
+
103
+ std::cerr << "Error: --vectorize=" << vectorize
104
+ << " must divide contiguous elements in --output-shape="
105
+ << output_shape.at(0) << "," << output_shape.at(1) << std::endl;
106
+
107
+ good = false;
108
+ }
109
+ }
110
+
111
+ /// Prints usage statement
112
+ static void print_usage(std::ostream &out) {
113
+ out
114
+ << " Options:\n"
115
+ << " --help Displays this help message.\n"
116
+ << " --extent=<extent> Specifies the layout-specific extent (as comma-delimited array).\n"
117
+ << " --stride=<stride> Specifies the layout-specific stride vector (comma-delimited array)\n"
118
+ << " --output-shape=<extent> Specifies the dimensions of a row-major output matrix. \n"
119
+ << " --vectorize=<vector length> If possible, vectorizes the output into vectors of consecutive elements\n";
120
+ }
121
+ };
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief CUTLASS layout visualization example
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include <map>
39
+ #include <memory>
40
+
41
+ #include "options.h"
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
/// Abstract interface implemented by every registered layout visualizer.
struct VisualizeLayoutBase {
  /// Renders the layout described by `Options`; returns false on failure.
  virtual bool visualize(Options const &) = 0;
  /// Checks internal consistency of the layout, reporting to `out`;
  /// `verbose` enables extra diagnostics.
  virtual bool verify(bool verbose, std::ostream &out) = 0;
  /// Emits the visualized layout as delimited text.
  virtual void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') = 0;
  /// Prints layout-specific help text (no-op by default).
  virtual std::ostream &print_help(std::ostream &out) {
    return out;
  }
  // Virtual destructor: instances are deleted through base pointers
  // (stored as std::unique_ptr<VisualizeLayoutBase>).
  virtual ~VisualizeLayoutBase() { }
};
54
+
55
+ /////////////////////////////////////////////////////////////////////////////////////////////////
56
+
57
+ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase> > &layouts);
58
+
59
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief CUTLASS layout visualization example
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include <algorithm>
39
+ #include <stdexcept>
40
+ #include <vector>
41
+
42
+ #include "cutlass/coord.h"
43
+ #include "cutlass/util/reference/host/tensor_foreach.h"
44
+
45
+ #include "register_layout.h"
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ /// Permits copying dynamic vectors into static-length vectors
50
+ template <typename TensorCoord, int Rank>
51
+ struct vector_to_coord {
52
+
53
+ vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {
54
+
55
+ coord[Rank - 1] = vec.at(Rank - 1);
56
+
57
+ if (Rank > 1) {
58
+ vector_to_coord<TensorCoord, Rank - 1>(coord, vec);
59
+ }
60
+ }
61
+ };
62
+
63
+ /// Permits copying dynamic vectors into static-length vectors
64
+ template <typename TensorCoord>
65
+ struct vector_to_coord<TensorCoord, 1> {
66
+
67
+ vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {
68
+
69
+ coord[0] = vec.at(0);
70
+ }
71
+ };
72
+
73
+ /// Permits copying dynamic vectors into static-length vectors
74
+ template <typename TensorCoord>
75
+ struct vector_to_coord<TensorCoord, 0> {
76
+
77
+ vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {
78
+
79
+ }
80
+ };
81
+
82
+ /////////////////////////////////////////////////////////////////////////////////////////////////
83
+
84
+ template <typename T>
85
+ std::ostream &operator<<(std::ostream &out, std::vector<T> const &vec) {
86
+ auto it = vec.begin();
87
+ if (it != vec.end()) {
88
+ out << *it;
89
+ for (++it; it != vec.end(); ++it) {
90
+ out << ", " << *it;
91
+ }
92
+ }
93
+ return out;
94
+ }
95
+
96
+ /////////////////////////////////////////////////////////////////////////////////////////////////
97
+
98
+ /// Permits copying static-length vectors into dynamic vectors
99
+ template <typename TensorCoord, int Rank>
100
+ struct coord_to_vector {
101
+
102
+ coord_to_vector(std::vector<int> &vec, TensorCoord const &coord) {
103
+
104
+ vec.at(Rank - 1) = coord[Rank - 1];
105
+ coord_to_vector<TensorCoord, Rank - 1>(vec, coord);
106
+ }
107
+ };
108
+
109
+ /// Permits copying static-length vectors into dynamic vectors
110
+ template <typename TensorCoord>
111
+ struct coord_to_vector<TensorCoord, 1> {
112
+
113
+ coord_to_vector(std::vector<int> &vec, TensorCoord const &coord) {
114
+
115
+ vec.at(0) = coord[0];
116
+ }
117
+ };
118
+
119
+ /// Permits copying static-length vectors into dynamic vectors
120
+ template <typename TensorCoord>
121
+ struct coord_to_vector<TensorCoord, 0> {
122
+
123
+ coord_to_vector(std::vector<int> &vec, TensorCoord const &coord) {
124
+ }
125
+ };
126
+
127
+ /////////////////////////////////////////////////////////////////////////////////////////////////
128
+
129
+ /// Structure representing an element in source memory
130
+ struct Element {
131
+
132
+ std::vector<int> coord; ///< logical coordinate of element (as vector)
133
+ int offset; ///< linear offset from source memory
134
+ int color; ///< enables coloring each element to indicate
135
+
136
+ /// Default ctor
137
+ inline Element(): offset(-1), color(0) { }
138
+
139
+ /// Construct from logical coordinate and initial offset
140
+ inline Element(
141
+ std::vector<int> const &coord_,
142
+ int offset_,
143
+ int color_ = 0
144
+ ):
145
+ coord(coord_), offset(offset_), color(color_) { }
146
+
147
+ /// Returns true if element is in a defined state
148
+ inline bool valid() const {
149
+ return offset >= 0;
150
+ }
151
+ };
152
+
153
+ /////////////////////////////////////////////////////////////////////////////////////////////////
154
+
155
+ /// Visualizes memory layouts by constructing a 'shape'
156
+ template <typename Layout_>
157
+ class VisualizeLayout : public VisualizeLayoutBase {
158
+ public:
159
+
160
+ using Layout = Layout_;
161
+ using TensorCoord = typename Layout::TensorCoord;
162
+ using Stride = typename Layout::Stride;
163
+
164
+ public:
165
+
166
+ Options options;
167
+ Layout layout;
168
+ TensorCoord extent;
169
+ std::vector<Element> elements;
170
+
171
+ public:
172
+
173
+ /// Initializes the problem space
174
+ VisualizeLayout() {
175
+
176
+ }
177
+
178
+ /// visualization method
179
+ bool visualize(Options const &options_) {
180
+
181
+ options = options_;
182
+
183
+ if (options.extent.size() != TensorCoord::kRank) {
184
+
185
+ std::cerr
186
+ << "--extent must have rank " << TensorCoord::kRank
187
+ << " (given: " << options.extent.size() << ")" << std::endl;
188
+
189
+ return false;
190
+ }
191
+
192
+ vector_to_coord<TensorCoord, TensorCoord::kRank>(extent, options.extent);
193
+
194
+ // Construct the layout for a packed tensor
195
+ if (options.stride.empty()) {
196
+
197
+ layout = Layout::packed(extent);
198
+ }
199
+ else if (options.stride.size() != Stride::kRank) {
200
+
201
+ std::cerr
202
+ << "--stride must have rank " << Stride::kRank
203
+ << " (given: " << options.stride.size() << ")" << std::endl;
204
+
205
+ return false;
206
+ }
207
+ else {
208
+ // Stride from
209
+ Stride stride;
210
+ vector_to_coord<Stride, Stride::kRank>(stride, options.stride);
211
+
212
+ layout = Layout(stride);
213
+ }
214
+
215
+ // Resize elements, setting elements to 'undefined' state
216
+ elements.resize(layout.capacity(extent));
217
+
218
+ // enumerate points in tensor space and assign
219
+ cutlass::reference::host::TensorForEachLambda(
220
+ extent,
221
+ [&](TensorCoord coord) {
222
+
223
+ std::vector<int> coord_vec(TensorCoord::kRank, 0);
224
+ coord_to_vector<TensorCoord, TensorCoord::kRank>(coord_vec, coord);
225
+
226
+ int offset = int(layout(coord));
227
+
228
+ if (offset >= int(elements.size())) {
229
+ std::cerr
230
+ << "Layout error - " << coord_vec
231
+ << " is out of range (computed offset: " << offset
232
+ << ", capacity: " << elements.size() << std::endl;
233
+
234
+ throw std::out_of_range("(TensorForEach) layout error - coordinate out of range");
235
+ }
236
+
237
+ elements.at(offset) = Element(coord_vec, offset);
238
+ });
239
+
240
+ return true;
241
+ }
242
+
243
+ /// Verifies the layout satisfies vectorization requirements
244
+ bool verify(bool verbose, std::ostream &out) {
245
+ return true;
246
+ }
247
+
248
+ private:
249
+
250
+ /// returns a pair (is_vectorizable, one_changing_rank) to determine if a
251
+ /// vector exists (consecutive logical coordinates or uniformly invalid)
252
+ /// at the given location.
253
+ std::pair< bool, int > _is_vectorizable(int i) const {
254
+ // (all elements are invalid) or
255
+ // (all elements are valid AND
256
+ // exactly one rank is changing AND
257
+ // elements are consecutive)
258
+
259
+ // Don't need vectorization.
260
+ if (options.vectorize <= 2) return std::make_pair(false, -1);
261
+
262
+ // Boundary check.
263
+ if (i > int(elements.size()) || (i + options.vectorize - 1) > int(elements.size()))
264
+ return std::make_pair(false, -1);
265
+
266
+ // Check if either all elements are valid or invalid.
267
+ bool all_elements_invalid = std::all_of(
268
+ elements.begin() + i, elements.begin() + i + options.vectorize,
269
+ [](Element const &e) { return !e.valid(); });
270
+
271
+ bool all_elements_valid = std::all_of(
272
+ elements.begin() + i, elements.begin() + i + options.vectorize,
273
+ [](Element const &e) { return e.valid(); });
274
+
275
+ if (!all_elements_invalid && !all_elements_valid)
276
+ return std::make_pair(false, -1);
277
+
278
+ // From here, it is vectorizable.
279
+ if (all_elements_invalid) return std::make_pair(true, -1);
280
+
281
+ // Check if only exactly one rank is changing.
282
+ int one_changing_rank = -1;
283
+ for (int j = 0; j < options.vectorize; ++j) {
284
+ for (int r = 0; r < TensorCoord::kRank; ++r) {
285
+ if (elements.at(i + j).coord.at(r) != elements.at(i).coord.at(r)) {
286
+ if (one_changing_rank == -1) {
287
+ one_changing_rank = r;
288
+ } else if (one_changing_rank != r) {
289
+ return std::make_pair(false, -1);
290
+ }
291
+ }
292
+ }
293
+ }
294
+
295
+ return std::make_pair(true, one_changing_rank);
296
+ }
297
+
298
+ /// Prints a vector of elements
299
+ void _print_vector(std::ostream &out, int i, int one_changing_rank) {
300
+ Element const &base_element = elements.at(i);
301
+ if (base_element.valid()) {
302
+ out << "(";
303
+ for (int r = 0; r < TensorCoord::kRank; ++r) {
304
+ if (r) {
305
+ out << ", ";
306
+ }
307
+
308
+ if (r == one_changing_rank) {
309
+ out
310
+ << base_element.coord.at(r)
311
+ << ".."
312
+ << (base_element.coord.at(r) + options.vectorize - 1);
313
+ }
314
+ else {
315
+ out << base_element.coord.at(r);
316
+ }
317
+ }
318
+ out << ")";
319
+ }
320
+ else {
321
+ out << " ";
322
+ }
323
+ }
324
+
325
+ /// Prints a single element
326
+ void _print_element(std::ostream &out, int k) {
327
+ Element const &element = elements.at(k);
328
+ if (element.valid()) {
329
+ out << "(";
330
+ for (int v = 0; v < TensorCoord::kRank; ++v) {
331
+ out << (v ? ", " : "") << element.coord.at(v);
332
+ }
333
+ out << ")";
334
+ }
335
+ else {
336
+ out << " ";
337
+ }
338
+ }
339
+
340
+ public:
341
+
342
+ /// Pretty-prints the layout to the console
343
+ void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') {
344
+ int row = -1;
345
+
346
+ for (int i = 0; i < int(elements.size()); i += options.vectorize) {
347
+ if (i % options.output_shape.at(0)) {
348
+ out << delim;
349
+ }
350
+ else {
351
+ if (row >= 0) {
352
+ out << new_line;
353
+ }
354
+ ++row;
355
+ if (row == options.output_shape.at(1)) {
356
+ out << new_line;
357
+ row = 0;
358
+ }
359
+ }
360
+
361
+ auto is_vector = _is_vectorizable(i);
362
+
363
+ if (is_vector.first) {
364
+ _print_vector(out, i, is_vector.second); // print a vector starting at element i
365
+ }
366
+ else {
367
+ for (int j = 0; j < options.vectorize; ++j) { // print individual elements [i..i+j)
368
+ _print_element(out, i + j);
369
+ }
370
+ }
371
+ }
372
+
373
+ out << new_line << std::flush;
374
+ }
375
+
376
+ /// Help message
377
+ virtual std::ostream &print_help(std::ostream &out) {
378
+ out << "TensorCoord rank " << TensorCoord::kRank << ", Stride rank: " << Stride::kRank;
379
+ return out;
380
+ }
381
+ };
382
+
383
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h ADDED
@@ -0,0 +1,719 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include <iostream>
35
+ #include <fstream>
36
+ #include <sstream>
37
+
38
+ #include "cutlass/cutlass.h"
39
+
40
+ #include "cutlass/conv/device/implicit_gemm_convolution.h"
41
+ #include "cutlass/reduction/device/reduce_split_k.h"
42
+ #include "cutlass/reduction/thread/reduction_operators.h"
43
+
44
+ #include "cutlass/util/host_tensor.h"
45
+ #include "cutlass/util/reference/host/tensor_fill.h"
46
+ #include "cutlass/util/reference/device/tensor_compare.h"
47
+ #include "cutlass/util/reference/host/tensor_compare.h"
48
+ #include "cutlass/util/reference/host/tensor_norm.h"
49
+
50
+ #include "cutlass/util/reference/host/convolution.h"
51
+ #include "cutlass/util/reference/device/convolution.h"
52
+ #include "cutlass/util/reference/device/tensor_relu.h"
53
+
54
+ #include "cutlass/core_io.h"
55
+ #include "cutlass/util/tensor_view_io.h"
56
+
57
+ #include "reference/device/tensor_scale_bias.h"
58
+ #include "helper.h"
59
+
60
+ #define CHECK_GT(val1, val2) \
61
+ if((val1) <= (val2)) \
62
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
63
+ #define CHECK_TRUE(val) \
64
+ if(!(val)) \
65
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
66
+
67
+
68
+ template <typename Conv2d0_, typename Conv2d1_>
69
+ class B2bNonFusedConv2dRun {
70
+ public:
71
+
72
+ using Conv2d0 = Conv2d0_;
73
+ using Conv2d1 = Conv2d1_;
74
+ using ElementAccumulator = typename Conv2d0::ElementAccumulator;
75
+ using ElementCompute = typename Conv2d0::ElementCompute;
76
+
77
+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator;
78
+ static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator,
79
+ "Fused convolution operators must be the same");
80
+
81
+ public:
82
+
83
+ /// Initialization
84
+ cutlass::Distribution::Kind init_A;
85
+ cutlass::Distribution::Kind init_B;
86
+ cutlass::Distribution::Kind init_C;
87
+ cutlass::Distribution::Kind init_Bias;
88
+ uint64_t seed;
89
+
90
+ cutlass::HostTensor<typename Conv2d0::ElementA, typename Conv2d0::LayoutA> tensor_A0;
91
+ cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
92
+ cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
93
+ cutlass::HostTensor<typename Conv2d0::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias0;
94
+ cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
95
+ cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;
96
+
97
+ cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
98
+ cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
99
+ cutlass::HostTensor<typename Conv2d1::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias1;
100
+ cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
101
+ cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;
102
+
103
+
104
+ public:
105
+
106
+ B2bNonFusedConv2dRun(
107
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
108
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
109
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
110
+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
111
+ uint64_t seed_ = 2080
112
+ ):
113
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) {
114
+
115
+ }
116
+
117
+ /// Helper to initialize a tensor view
118
+ template <typename Element, typename Layout>
119
+ void initialize_tensor(
120
+ cutlass::TensorView<Element, Layout> view,
121
+ cutlass::Distribution::Kind dist_kind,
122
+ uint64_t seed) {
123
+
124
+ if (dist_kind == cutlass::Distribution::Uniform) {
125
+
126
+ int scope;
127
+ int bits = cutlass::sizeof_bits<Element>::value;
128
+
129
+ if (bits <= 16) {
130
+ scope = 2;
131
+ }
132
+ else {
133
+ scope = 8;
134
+ }
135
+ cutlass::reference::host::TensorFillRandomUniform(
136
+ view, seed, scope, -scope, 0);
137
+ }
138
+ else if (dist_kind == cutlass::Distribution::Identity) {
139
+
140
+ cutlass::reference::host::TensorFillIdentity(view);
141
+ }
142
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
143
+
144
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
145
+ }
146
+ else if (dist_kind == cutlass::Distribution::Sequential) {
147
+
148
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
149
+ }
150
+ else if (dist_kind == cutlass::Distribution::AllZeros) {
151
+ cutlass::reference::host::TensorFill(view, Element(0));
152
+ }
153
+ else if (dist_kind == cutlass::Distribution::AllOnes) {
154
+ cutlass::reference::host::TensorFill(view, Element(1));
155
+ }
156
+ else {
157
+ std::cerr << "Not implemented\n";
158
+ }
159
+ }
160
+
161
+ void initialize(
162
+ cutlass::conv::Conv2dProblemSize const &problem_size_0,
163
+ cutlass::conv::Conv2dProblemSize const &problem_size_1,
164
+ uint64_t seed = 2019) {
165
+
166
+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
167
+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
168
+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
169
+ tensor_Bias0.resize({1, 1, 1, problem_size_0.K});
170
+ tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
171
+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
172
+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
173
+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
174
+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
175
+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
176
+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
177
+
178
+ initialize_tensor(tensor_A0.host_view(), init_A, seed);
179
+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
180
+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
181
+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
182
+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
183
+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
184
+ initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
185
+
186
+ tensor_A0.sync_device();
187
+ tensor_B0.sync_device();
188
+ tensor_C0.sync_device();
189
+ tensor_Bias0.sync_device();
190
+ tensor_D0_computed.sync_device();
191
+ tensor_D0_reference.sync_device();
192
+ tensor_B1.sync_device();
193
+ tensor_C1.sync_device();
194
+ tensor_Bias1.sync_device();
195
+ tensor_D1_computed.sync_device();
196
+ tensor_D1_reference.sync_device();
197
+ }
198
+
199
+ /// Executes one test
200
+ bool run(
201
+ cutlass::conv::Conv2dProblemSize const &problem_size_0,
202
+ cutlass::conv::Conv2dProblemSize const &problem_size_1,
203
+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
204
+ ElementCompute alpha0 = ElementCompute(1),
205
+ ElementCompute beta0 = ElementCompute(0),
206
+ ElementCompute alpha1 = ElementCompute(1),
207
+ ElementCompute beta1 = ElementCompute(0),
208
+ bool relu = true,
209
+ int warm_ups = 1,
210
+ int runs = 100) {
211
+
212
+ initialize(problem_size_0, problem_size_1);
213
+
214
+ // configure the operator
215
+ Conv2d0 conv2d_op_0;
216
+ Conv2d1 conv2d_op_1;
217
+
218
+ typename Conv2d0::Arguments conv2d_args_0(
219
+ problem_size_0,
220
+ tensor_A0.device_ref(),
221
+ tensor_B0.device_ref(),
222
+ {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)},
223
+ tensor_D0_computed.device_ref(),
224
+ {alpha0, beta0},
225
+ split_k_mode
226
+ );
227
+ typename Conv2d1::Arguments conv2d_args_1(
228
+ problem_size_1,
229
+ tensor_D0_computed.device_ref(),
230
+ tensor_B1.device_ref(),
231
+ {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)},
232
+ tensor_D1_computed.device_ref(),
233
+ {alpha1, beta1},
234
+ split_k_mode
235
+ );
236
+
237
+
238
+ cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0);
239
+
240
+ CUTLASS_CHECK(status);
241
+
242
+ status = conv2d_op_1.initialize(conv2d_args_1);
243
+
244
+ CUTLASS_CHECK(status);
245
+
246
+ for(int i = 0; i < warm_ups; i++) {
247
+ status = conv2d_op_0();
248
+ CUTLASS_CHECK(status);
249
+ status = conv2d_op_1();
250
+ CUTLASS_CHECK(status);
251
+ }
252
+
253
+ //
254
+ // Run Conv2d
255
+ //
256
+ cudaEvent_t start, stop1, stop2;
257
+ cudaEventCreate(&start);
258
+ cudaEventCreate(&stop1);
259
+ cudaEventCreate(&stop2);
260
+
261
+ cudaEventRecord(start);
262
+
263
+
264
+ for(int i = 0; i < runs; i++) {
265
+ // run conv2d operator
266
+ status = conv2d_op_0();
267
+ CUTLASS_CHECK(status);
268
+ }
269
+ cudaEventRecord(stop1);
270
+
271
+ for(int i = 0; i < runs; i++) {
272
+ // run conv2d operator
273
+ status = conv2d_op_1();
274
+ CUTLASS_CHECK(status);
275
+ }
276
+ cudaEventRecord(stop2);
277
+ cudaDeviceSynchronize();
278
+ float conv2d0Time, conv2d1Time, totalTime;
279
+ cudaEventElapsedTime(&conv2d0Time, start, stop1);
280
+ cudaEventElapsedTime(&conv2d1Time, stop1, stop2);
281
+ cudaEventElapsedTime(&totalTime, start, stop2);
282
+ std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
283
+ std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
284
+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
285
+
286
+ tensor_D0_computed.sync_host();
287
+ tensor_D1_computed.sync_host();
288
+
289
+ bool passed = false;
290
+
291
+ cutlass::reference::device::Conv2d<
292
+ typename Conv2d0::ElementA,
293
+ typename Conv2d0::LayoutA,
294
+ typename Conv2d0::ElementB,
295
+ typename Conv2d0::LayoutB,
296
+ typename Conv2d0::ElementC,
297
+ typename Conv2d0::LayoutC,
298
+ ElementCompute,
299
+ ElementAccumulator
300
+ >(
301
+ kConvolutionalOperator,
302
+ problem_size_0,
303
+ tensor_A0.device_ref(),
304
+ tensor_B0.device_ref(),
305
+ {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)},
306
+ tensor_D0_reference.device_ref(),
307
+ alpha0,
308
+ beta0);
309
+
310
+ if(relu) {
311
+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
312
+ }
313
+
314
+ cutlass::reference::device::Conv2d<
315
+ typename Conv2d1::ElementA,
316
+ typename Conv2d1::LayoutA,
317
+ typename Conv2d1::ElementB,
318
+ typename Conv2d1::LayoutB,
319
+ typename Conv2d1::ElementC,
320
+ typename Conv2d1::LayoutC,
321
+ ElementCompute,
322
+ ElementAccumulator
323
+ >(
324
+ kConvolutionalOperator,
325
+ problem_size_1,
326
+ tensor_D0_reference.device_ref(),
327
+ tensor_B1.device_ref(),
328
+ {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)},
329
+ tensor_D1_reference.device_ref(),
330
+ alpha1,
331
+ beta1);
332
+
333
+ if(relu) {
334
+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
335
+ }
336
+
337
+ cudaError_t result = cudaDeviceSynchronize();
338
+ CHECK_TRUE(result == cudaSuccess);
339
+
340
+ // sync host (copy device data to host) for dumping error output in case of mismatches
341
+ tensor_D0_reference.sync_host();
342
+ tensor_D1_reference.sync_host();
343
+
344
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0);
345
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
346
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
347
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
348
+
349
+ passed = cutlass::reference::host::TensorEquals(
350
+ tensor_D1_computed.host_view(),
351
+ tensor_D1_reference.host_view());
352
+
353
+ CHECK_TRUE(passed);
354
+
355
+ if (!passed) {
356
+ std::stringstream fname;
357
+
358
+ fname << "error_B2bImplicitGemm_device_nonfused.txt";
359
+ std::cerr << "Dumping results in " << fname.str() << "\n";
360
+
361
+ std::ofstream results(fname.str());
362
+
363
+ results << problem_size_0 << std::endl;
364
+ results << problem_size_1 << std::endl;
365
+
366
+ results
367
+ << "\nA0:\n" << tensor_A0.host_view() << "\n"
368
+ << "\nB0:\n" << tensor_B0.host_view() << "\n"
369
+ << "\nC0:\n" << tensor_C0.host_view() << "\n"
370
+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
371
+ << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
372
+ << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
373
+ << "\nB1:\n" << tensor_B1.host_view() << "\n"
374
+ << "\nC1:\n" << tensor_C1.host_view() << "\n"
375
+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
376
+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
377
+ << "\nD1 computed:\n" << tensor_D1_computed.host_view();
378
+
379
+
380
+ }
381
+
382
+ return passed;
383
+ }
384
+
385
+ };
386
+
387
+ template <typename B2bConv2d_>
388
+ class B2bFusedConv2dRun {
389
+ public:
390
+
391
+ using B2bConv2d = B2bConv2d_;
392
+ using ElementAccumulator = typename B2bConv2d::ElementAccumulator;
393
+ using ElementCompute = typename B2bConv2d::ElementCompute;
394
+
395
+ static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator;
396
+
397
+ public:
398
+
399
+ /// Initialization
400
+ cutlass::Distribution::Kind init_A;
401
+ cutlass::Distribution::Kind init_B;
402
+ cutlass::Distribution::Kind init_C;
403
+ cutlass::Distribution::Kind init_Scale;
404
+ cutlass::Distribution::Kind init_Bias;
405
+ uint64_t seed;
406
+
407
+ cutlass::HostTensor<typename B2bConv2d::ElementA, typename B2bConv2d::LayoutA> tensor_A0;
408
+ cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0;
409
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
410
+ cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
411
+ cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
412
+ cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
413
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
414
+
415
+ cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
416
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
417
+ cutlass::HostTensor<typename B2bConv2d::ElementCompute, typename B2bConv2d::LayoutC> tensor_Bias1;
418
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
419
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;
420
+
421
+
422
+ public:
423
+
424
+ B2bFusedConv2dRun(
425
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
426
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
427
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
428
+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
429
+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
430
+ uint64_t seed_ = 2080
431
+ ):
432
+ init_A(init_A_), init_B(init_B_), init_C(init_C_),
433
+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) {
434
+
435
+ }
436
+
437
+ /// Helper to initialize a tensor view
438
+ template <typename Element, typename Layout>
439
+ void initialize_tensor(
440
+ cutlass::TensorView<Element, Layout> view,
441
+ cutlass::Distribution::Kind dist_kind,
442
+ uint64_t seed) {
443
+
444
+ if (dist_kind == cutlass::Distribution::Uniform) {
445
+
446
+ int scope;
447
+ int bits = cutlass::sizeof_bits<Element>::value;
448
+
449
+ if (bits <= 16) {
450
+ scope = 2;
451
+ }
452
+ else {
453
+ scope = 8;
454
+ }
455
+ cutlass::reference::host::TensorFillRandomUniform(
456
+ view, seed, scope, -scope, 0);
457
+ }
458
+ else if (dist_kind == cutlass::Distribution::Identity) {
459
+
460
+ cutlass::reference::host::TensorFillIdentity(view);
461
+ }
462
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
463
+
464
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
465
+ }
466
+ else if (dist_kind == cutlass::Distribution::Sequential) {
467
+
468
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
469
+ }
470
+ else if (dist_kind == cutlass::Distribution::AllZeros) {
471
+ cutlass::reference::host::TensorFill(view, Element(0));
472
+ }
473
+ else if (dist_kind == cutlass::Distribution::AllOnes) {
474
+ cutlass::reference::host::TensorFill(view, Element(1));
475
+ }
476
+ else {
477
+ }
478
+ }
479
+
480
+ void initialize(
481
+ cutlass::conv::Conv2dProblemSize const &problem_size_0,
482
+ cutlass::conv::Conv2dProblemSize const &problem_size_1,
483
+ ElementCompute alpha0,
484
+ ElementCompute alpha1,
485
+ uint64_t seed = 2019) {
486
+
487
+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
488
+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
489
+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
490
+ if(alpha0 == ElementCompute(0)) //per-channel scale
491
+ tensor_Scale0.resize({1, problem_size_0.K});
492
+ tensor_Bias0.resize({1, problem_size_0.K});
493
+ tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
494
+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
495
+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
496
+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
497
+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
498
+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
499
+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
500
+
501
+ initialize_tensor(tensor_A0.host_view(), init_A, seed);
502
+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
503
+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
504
+ if(alpha0 == ElementCompute(0)) //per-channel scale
505
+ initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61);
506
+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
507
+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
508
+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
509
+ initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
510
+
511
+ tensor_A0.sync_device();
512
+ tensor_B0.sync_device();
513
+ tensor_C0.sync_device();
514
+ if(alpha0 == ElementCompute(0)) //per-channel scale
515
+ tensor_Scale0.sync_device();
516
+ tensor_Bias0.sync_device();
517
+ tensor_D0_reference.sync_device();
518
+ tensor_B1.sync_device();
519
+ tensor_C1.sync_device();
520
+ tensor_Bias1.sync_device();
521
+ tensor_D1_computed.sync_device();
522
+ tensor_D1_reference.sync_device();
523
+ }
524
+
525
+ /// Executes one test
526
+ bool run(
527
+ cutlass::conv::Conv2dProblemSize const &problem_size_0,
528
+ cutlass::conv::Conv2dProblemSize const &problem_size_1,
529
+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
530
+ ElementCompute alpha0 = ElementCompute(1),
531
+ ElementCompute beta0 = ElementCompute(0),
532
+ ElementCompute alpha1 = ElementCompute(1),
533
+ ElementCompute beta1 = ElementCompute(0),
534
+ bool relu = true,
535
+ int warm_ups = 1,
536
+ int runs = 100) {
537
+
538
+ initialize(problem_size_0, problem_size_1, alpha0, alpha1);
539
+
540
+ // configure the operator
541
+ B2bConv2d b2b_conv2d_op;
542
+
543
+ typename B2bConv2d::Arguments b2b_conv2d_args(
544
+ problem_size_0,
545
+ problem_size_1,
546
+ tensor_A0.device_ref(),
547
+ tensor_B0.device_ref(),
548
+ tensor_C0.device_ref(),
549
+ tensor_Scale0.device_ref(),
550
+ tensor_Bias0.device_ref(),
551
+ tensor_B1.device_ref(),
552
+ {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)},
553
+ tensor_D1_computed.device_ref(),
554
+ {alpha0, beta0},
555
+ {alpha1, beta1},
556
+ split_k_mode
557
+ );
558
+
559
+ cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args);
560
+
561
+ if(status != cutlass::Status::kSuccess) {
562
+ std::cout << "Problem sizes not supported.\n"
563
+ << "Requirments:\n"
564
+ << " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n"
565
+ << " problem_size_0.K = problem_size_1.C\n"
566
+ << " problem_size_1.R = problem_size_1.S = 1\n"
567
+ << " ThreadblockShape0::kN = problem_size_0.K\n"
568
+ << " ThreadblockShape1::kN = problem_size_1.K" << std::endl;
569
+ }
570
+
571
+ CUTLASS_CHECK(status);
572
+
573
+ status = b2b_conv2d_op.initialize(b2b_conv2d_args);
574
+
575
+ CUTLASS_CHECK(status);
576
+
577
+ for(int i = 0; i < warm_ups; i++) {
578
+ status = b2b_conv2d_op();
579
+ CUTLASS_CHECK(status);
580
+ }
581
+
582
+ //
583
+ // Run the Conv2d
584
+ //
585
+
586
+ cudaEvent_t start, stop;
587
+ cudaEventCreate(&start);
588
+ cudaEventCreate(&stop);
589
+
590
+ cudaEventRecord(start);
591
+
592
+ for(int i = 0; i < runs; i++) {
593
+
594
+ // run conv2d operator
595
+ status = b2b_conv2d_op();
596
+ CUTLASS_CHECK(status);
597
+ }
598
+
599
+ cudaEventRecord(stop);
600
+ cudaDeviceSynchronize();
601
+ float conv2dTime;
602
+ cudaEventElapsedTime(&conv2dTime, start, stop);
603
+ std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";
604
+
605
+ tensor_D1_computed.sync_host();
606
+
607
+ bool passed = false;
608
+
609
+ cutlass::reference::device::Conv2d<
610
+ typename B2bConv2d::ElementA,
611
+ typename B2bConv2d::LayoutA,
612
+ typename B2bConv2d::ElementB,
613
+ typename B2bConv2d::LayoutB,
614
+ ElementAccumulator,
615
+ typename B2bConv2d::LayoutC,
616
+ ElementAccumulator,
617
+ ElementAccumulator
618
+ >(
619
+ kConvolutionalOperator,
620
+ problem_size_0,
621
+ tensor_A0.device_ref(),
622
+ tensor_B0.device_ref(),
623
+ tensor_Z0_reference.device_ref(),
624
+ tensor_Z0_reference.device_ref(),
625
+ ElementAccumulator(1), // intermediate alpha = 1
626
+ ElementAccumulator(0) // beta = 0
627
+ );
628
+
629
+ cutlass::reference::device::TensorScaleBiasConv2d<
630
+ ElementAccumulator,
631
+ typename B2bConv2d::ElementC,
632
+ typename B2bConv2d::LayoutC,
633
+ ElementCompute,
634
+ typename B2bConv2d::LayoutScaleBias
635
+ >(
636
+ problem_size_0,
637
+ tensor_Z0_reference.device_ref(),
638
+ tensor_D0_reference.device_ref(),
639
+ alpha0,
640
+ tensor_Scale0.device_ref(),
641
+ tensor_Bias0.device_ref()
642
+ );
643
+
644
+ if(relu) {
645
+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
646
+ }
647
+
648
+ cutlass::reference::device::Conv2d<
649
+ typename B2bConv2d::ElementA,
650
+ typename B2bConv2d::LayoutA,
651
+ typename B2bConv2d::ElementB,
652
+ typename B2bConv2d::LayoutB,
653
+ typename B2bConv2d::ElementC,
654
+ typename B2bConv2d::LayoutC,
655
+ ElementCompute,
656
+ ElementAccumulator
657
+ >(
658
+ kConvolutionalOperator,
659
+ problem_size_1,
660
+ tensor_D0_reference.device_ref(),
661
+ tensor_B1.device_ref(),
662
+ {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)},
663
+ tensor_D1_reference.device_ref(),
664
+ alpha1,
665
+ beta1);
666
+
667
+ if(relu) {
668
+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
669
+ }
670
+
671
+ cudaError_t result = cudaDeviceSynchronize();
672
+ CHECK_TRUE(result == cudaSuccess);
673
+
674
+ // sync host (copy device data to host) for dumping error output in case of mismatches
675
+ tensor_D0_reference.sync_host();
676
+ tensor_D1_reference.sync_host();
677
+
678
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
679
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
680
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
681
+
682
+ passed = cutlass::reference::host::TensorEquals(
683
+ tensor_D1_computed.host_view(),
684
+ tensor_D1_reference.host_view());
685
+
686
+ CHECK_TRUE(passed);
687
+
688
+ if (!passed) {
689
+ std::stringstream fname;
690
+
691
+ fname << "error_B2bImplicitGemm_device_fused.txt";
692
+ std::cerr << "Dumping results in " << fname.str() << "\n";
693
+
694
+ std::ofstream results(fname.str());
695
+
696
+ results << problem_size_0 << std::endl;
697
+ results << problem_size_1 << std::endl;
698
+
699
+ results
700
+ << "\nA0:\n" << tensor_A0.host_view() << "\n"
701
+ << "\nB0:\n" << tensor_B0.host_view() << "\n"
702
+ << "\nC0:\n" << tensor_C0.host_view() << "\n"
703
+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
704
+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
705
+ << "\nB1:\n" << tensor_B1.host_view() << "\n"
706
+ << "\nC1:\n" << tensor_C1.host_view() << "\n"
707
+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
708
+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
709
+ << "\nD1 computed:\n" << tensor_D1_computed.host_view();
710
+
711
+
712
+ }
713
+
714
+ return passed;
715
+ }
716
+
717
+ };
718
+
719
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include <iostream>
34
+ #include <fstream>
35
+ #include <sstream>
36
+
37
+ #include "cutlass/util/host_tensor.h"
38
+ #include "cutlass/util/tensor_view_io.h"
39
+ #include "cutlass/util/distribution.h"
40
+ #include "cutlass/util/reference/host/tensor_fill.h"
41
+ #include "cutlass/util/reference/host/tensor_copy.h"
42
+ #include "cutlass/util/reference/host/tensor_compare.h"
43
+ #include "cutlass/util/reference/host/tensor_norm.h"
44
+ #include "cutlass/util/reference/device/gemm.h"
45
+ #include "cutlass/util/reference/device/gemm_complex.h"
46
+ #include "cutlass/util/reference/device/tensor_relu.h"
47
+
48
+ #include "reference/device/tensor_scale_bias.h"
49
+ #include "helper.h"
50
+
51
+ #define CHECK_GT(val1, val2) \
52
+ if((val1) <= (val2)) \
53
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
54
+ #define CHECK_TRUE(val) \
55
+ if(!(val)) \
56
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
57
+
58
+ ////////////////////////////////////////////////////////////////////////////////
59
+
60
+ template <typename Gemm0_, typename Gemm1_>
61
+ struct B2bNonFusedGemmRun
62
+ {
63
+
64
+ using Gemm0 = Gemm0_;
65
+ using Gemm1 = Gemm1_;
66
+ using ElementAccumulator = typename Gemm0::ElementAccumulator;
67
+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
68
+
69
+ /// Initialization
70
+ cutlass::Distribution::Kind init_A;
71
+ cutlass::Distribution::Kind init_B;
72
+ cutlass::Distribution::Kind init_C;
73
+ cutlass::Distribution::Kind init_Bias;
74
+ uint64_t seed;
75
+
76
+ //
77
+ // Methods
78
+ //
79
+
80
+ B2bNonFusedGemmRun(
81
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
82
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
83
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
84
+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
85
+ uint64_t seed_ = 2080
86
+ ):
87
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }
88
+
89
+ /// Helper to initialize a tensor view
90
+ template <typename Element, typename Layout>
91
+ bool initialize_tensor(
92
+ cutlass::TensorView<Element, Layout> view,
93
+ cutlass::Distribution::Kind dist_kind,
94
+ uint64_t seed) {
95
+
96
+ if (dist_kind == cutlass::Distribution::Uniform) {
97
+
98
+ cutlass::reference::host::TensorFillRandomUniform(
99
+ view, seed, 2, -2, 0);
100
+ }
101
+ else if (dist_kind == cutlass::Distribution::Identity) {
102
+
103
+ cutlass::reference::host::TensorFillIdentity(view);
104
+ }
105
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
106
+
107
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
108
+ }
109
+ else if (dist_kind == cutlass::Distribution::Sequential) {
110
+
111
+ cutlass::reference::host::BlockFillSequential(
112
+ view.data(), view.capacity());
113
+ }
114
+ else if (dist_kind == cutlass::Distribution::AllZeros) {
115
+ cutlass::reference::host::TensorFill(view, Element(0));
116
+ }
117
+ else if (dist_kind == cutlass::Distribution::AllOnes) {
118
+ cutlass::reference::host::TensorFill(view, Element(1));
119
+ }
120
+ else {
121
+ std::cerr << "Not implemented\n";
122
+ return false;
123
+ }
124
+
125
+ return true;
126
+ }
127
+
128
+
129
+
130
+
131
+ /// Executes one test
132
+ bool run(
133
+ cutlass::gemm::GemmCoord problem_size_0,
134
+ cutlass::gemm::GemmCoord problem_size_1,
135
+ ElementCompute alpha0 = ElementCompute(1),
136
+ ElementCompute beta0 = ElementCompute(0),
137
+ ElementCompute alpha1 = ElementCompute(1),
138
+ ElementCompute beta1 = ElementCompute(0),
139
+ bool relu = true,
140
+ int warm_ups = 1,
141
+ int runs = 100) {
142
+
143
+ //
144
+ // Allocate the GEMM workspace
145
+ //
146
+
147
+ cutlass::HostTensor<
148
+ typename Gemm0::ElementA,
149
+ typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
150
+
151
+ cutlass::HostTensor<
152
+ typename Gemm0::ElementB,
153
+ typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
154
+
155
+ cutlass::HostTensor<
156
+ typename Gemm0::ElementC,
157
+ typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
158
+
159
+ cutlass::HostTensor<
160
+ ElementCompute,
161
+ typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
162
+
163
+ cutlass::HostTensor<
164
+ typename Gemm0::ElementC,
165
+ typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
166
+
167
+ cutlass::HostTensor<
168
+ typename Gemm0::ElementC,
169
+ typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
170
+
171
+ cutlass::HostTensor<
172
+ typename Gemm1::ElementB,
173
+ typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
174
+
175
+ cutlass::HostTensor<
176
+ typename Gemm1::ElementC,
177
+ typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
178
+
179
+ cutlass::HostTensor<
180
+ ElementCompute,
181
+ typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
182
+
183
+ cutlass::HostTensor<
184
+ typename Gemm1::ElementC,
185
+ typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
186
+
187
+ cutlass::HostTensor<
188
+ typename Gemm1::ElementC,
189
+ typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
190
+
191
+
192
+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
193
+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
194
+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
195
+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
196
+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
197
+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
198
+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));
199
+
200
+ cutlass::reference::host::TensorFill(
201
+ tensor_D0.host_view());
202
+ cutlass::reference::host::TensorFill(
203
+ tensor_D1.host_view());
204
+ cutlass::reference::host::TensorFill(
205
+ reference_D0.host_view());
206
+ cutlass::reference::host::TensorFill(
207
+ reference_D1.host_view());
208
+
209
+ tensor_A0.sync_device();
210
+ tensor_B0.sync_device();
211
+ tensor_C0.sync_device();
212
+ tensor_Bias0.sync_device();
213
+ tensor_D0.sync_device();
214
+ tensor_B1.sync_device();
215
+ tensor_C1.sync_device();
216
+ tensor_Bias1.sync_device();
217
+ tensor_D1.sync_device();
218
+ reference_D0.sync_device();
219
+ reference_D1.sync_device();
220
+
221
+ //
222
+ // Initialize the GEMM operator
223
+ //
224
+
225
+ typename Gemm0::Arguments arguments_0{
226
+ problem_size_0,
227
+ tensor_A0.device_ref(),
228
+ tensor_B0.device_ref(),
229
+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
230
+ tensor_D0.device_ref(),
231
+ {alpha0, beta0}
232
+ };
233
+
234
+ typename Gemm1::Arguments arguments_1{
235
+ problem_size_1,
236
+ tensor_D0.device_ref(),
237
+ tensor_B1.device_ref(),
238
+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
239
+ tensor_D1.device_ref(),
240
+ {alpha1, beta1}
241
+ };
242
+
243
+
244
+ Gemm0 gemm_op_0;
245
+ Gemm1 gemm_op_1;
246
+
247
+ cutlass::Status status = gemm_op_0.initialize(arguments_0);
248
+
249
+ CUTLASS_CHECK(status);
250
+
251
+ status = gemm_op_1.initialize(arguments_1);
252
+
253
+ CUTLASS_CHECK(status);
254
+
255
+ for(int i = 0; i < warm_ups; i++) {
256
+ status = gemm_op_0();
257
+ CUTLASS_CHECK(status);
258
+ status = gemm_op_1();
259
+ CUTLASS_CHECK(status);
260
+ }
261
+
262
+ //
263
+ // Run the GEMM
264
+ //
265
+ cudaEvent_t start, stop1, stop2;
266
+ cudaEventCreate(&start);
267
+ cudaEventCreate(&stop1);
268
+ cudaEventCreate(&stop2);
269
+
270
+ cudaEventRecord(start);
271
+
272
+ for(int i = 0; i < runs; i++) {
273
+ status = gemm_op_0();
274
+
275
+ CUTLASS_CHECK(status);
276
+ }
277
+ cudaEventRecord(stop1);
278
+ for(int i = 0; i < runs; i++) {
279
+ status = gemm_op_1();
280
+
281
+ CUTLASS_CHECK(status);
282
+ }
283
+
284
+ cudaEventRecord(stop2);
285
+ cudaDeviceSynchronize();
286
+ float gemm0Time, gemm1Time, totalTime;
287
+ cudaEventElapsedTime(&gemm0Time, start, stop1);
288
+ cudaEventElapsedTime(&gemm1Time, stop1, stop2);
289
+ cudaEventElapsedTime(&totalTime, start, stop2);
290
+ std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
291
+ std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
292
+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
293
+
294
+ tensor_D0.sync_host();
295
+ tensor_D1.sync_host();
296
+
297
+ //
298
+ // Verify
299
+ //
300
+ cutlass::reference::device::Gemm<
301
+ typename Gemm0::ElementA, typename Gemm0::LayoutA,
302
+ typename Gemm0::ElementB, typename Gemm0::LayoutB,
303
+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
304
+ ElementAccumulator, typename Gemm0::Operator>
305
+ reference_gemm_0;
306
+
307
+ cutlass::reference::device::Gemm<
308
+ typename Gemm1::ElementA, typename Gemm1::LayoutA,
309
+ typename Gemm1::ElementB, typename Gemm1::LayoutB,
310
+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
311
+ ElementAccumulator, typename Gemm1::Operator>
312
+ reference_gemm_1;
313
+
314
+ reference_gemm_0(
315
+ problem_size_0,
316
+ alpha0,
317
+ tensor_A0.device_ref(),
318
+ tensor_B0.device_ref(),
319
+ beta0,
320
+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
321
+ reference_D0.device_ref()
322
+ );
323
+
324
+ if(relu) {
325
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
326
+ }
327
+
328
+ reference_gemm_1(
329
+ problem_size_1,
330
+ alpha1,
331
+ reference_D0.device_ref(),
332
+ tensor_B1.device_ref(),
333
+ beta1,
334
+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
335
+ reference_D1.device_ref()
336
+ );
337
+
338
+ if(relu) {
339
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
340
+ }
341
+
342
+ // Wait for kernels to finish
343
+ cudaDeviceSynchronize();
344
+ reference_D0.sync_host();
345
+ reference_D1.sync_host();
346
+
347
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
348
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
349
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
350
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
351
+
352
+ bool passed = cutlass::reference::host::TensorEquals(
353
+ reference_D1.host_view(),
354
+ tensor_D1.host_view());
355
+
356
+ CHECK_TRUE(passed);
357
+ if (!passed) {
358
+
359
+ std::stringstream fname;
360
+
361
+ fname << "error_B2bGemm_device_nonfused.txt";
362
+ std::cerr << "Dumping results in " << fname.str() << "\n";
363
+
364
+ std::ofstream file(fname.str());
365
+
366
+ file
367
+ << "A0 =\n" << tensor_A0.host_view()
368
+ << "\nB0 =\n" << tensor_B0.host_view()
369
+ << "\nC0 =\n" << tensor_C0.host_view()
370
+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
371
+ << "\nD0 =\n" << tensor_D0.host_view()
372
+ << "\nB1 =\n" << tensor_B1.host_view()
373
+ << "\nC1 =\n" << tensor_C1.host_view()
374
+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
375
+ << "\n\nReference =\n" << reference_D1.host_view()
376
+ << "\nComputed =\n" << tensor_D1.host_view();
377
+ }
378
+ return passed;
379
+ }
380
+ };
381
+
382
+ template <typename B2bGemm_>
383
+ struct B2bFusedGemmRun
384
+ {
385
+
386
+ using B2bGemm = B2bGemm_;
387
+ using ElementAccumulator = typename B2bGemm::ElementAccumulator;
388
+ using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
389
+
390
+ /// Initialization
391
+ cutlass::Distribution::Kind init_A;
392
+ cutlass::Distribution::Kind init_B;
393
+ cutlass::Distribution::Kind init_C;
394
+ cutlass::Distribution::Kind init_Scale;
395
+ cutlass::Distribution::Kind init_Bias;
396
+ uint64_t seed;
397
+
398
+ //
399
+ // Methods
400
+ //
401
+
402
+ B2bFusedGemmRun(
403
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
404
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
405
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
406
+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
407
+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
408
+ uint64_t seed_ = 2080
409
+ ):
410
+ init_A(init_A_), init_B(init_B_), init_C(init_C_),
411
+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
412
+
413
+ /// Helper to initialize a tensor view
414
+ template <typename Element, typename Layout>
415
+ bool initialize_tensor(
416
+ cutlass::TensorView<Element, Layout> view,
417
+ cutlass::Distribution::Kind dist_kind,
418
+ uint64_t seed) {
419
+
420
+ if (dist_kind == cutlass::Distribution::Uniform) {
421
+
422
+ cutlass::reference::host::TensorFillRandomUniform(
423
+ view, seed, 2, -2, 0);
424
+ }
425
+ else if (dist_kind == cutlass::Distribution::Identity) {
426
+
427
+ cutlass::reference::host::TensorFillIdentity(view);
428
+ }
429
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
430
+
431
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
432
+ }
433
+ else if (dist_kind == cutlass::Distribution::Sequential) {
434
+
435
+ cutlass::reference::host::BlockFillSequential(
436
+ view.data(), view.capacity());
437
+ }
438
+ else if (dist_kind == cutlass::Distribution::AllZeros) {
439
+ cutlass::reference::host::TensorFill(view, Element(0));
440
+ }
441
+ else if (dist_kind == cutlass::Distribution::AllOnes) {
442
+ cutlass::reference::host::TensorFill(view, Element(1));
443
+ }
444
+ else {
445
+ std::cerr << "Not implemented\n";
446
+ return false;
447
+ }
448
+
449
+ return true;
450
+ }
451
+
452
+
453
+
454
+
455
+ /// Executes one test
456
+ bool run(
457
+ cutlass::gemm::GemmCoord problem_size_0,
458
+ cutlass::gemm::GemmCoord problem_size_1,
459
+ ElementCompute alpha0 = ElementCompute(1),
460
+ ElementCompute beta0 = ElementCompute(0),
461
+ ElementCompute alpha1 = ElementCompute(1),
462
+ ElementCompute beta1 = ElementCompute(0),
463
+ cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
464
+
465
+ // batch_count is used as split-k when mode is kGemm according
466
+ // to the GemmUniversal interface
467
+
468
+ int batch_count = 1,
469
+ int64_t batch_stride_A0 = 0,
470
+ int64_t batch_stride_B0 = 0,
471
+ int64_t batch_stride_C0 = 0,
472
+ int64_t batch_stride_B1 = 0,
473
+ int64_t batch_stride_C1 = 0,
474
+ int64_t batch_stride_D1 = 0,
475
+ int64_t batch_stride_Bias0 = 0,
476
+ int64_t batch_stride_Scale0 = 0,
477
+ bool relu = true,
478
+ int warm_ups = 1,
479
+ int runs = 100) {
480
+
481
+ //
482
+ // Allocate the GEMM workspace
483
+ //
484
+
485
+ cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
486
+ cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
487
+ cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
488
+ cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
489
+ cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
490
+
491
+ cutlass::HostTensor<
492
+ typename B2bGemm::ElementA,
493
+ typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
494
+
495
+ cutlass::HostTensor<
496
+ typename B2bGemm::ElementB,
497
+ typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
498
+
499
+ cutlass::HostTensor<
500
+ typename B2bGemm::ElementC,
501
+ typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
502
+
503
+ cutlass::HostTensor<
504
+ typename B2bGemm::ElementScaleBias,
505
+ typename B2bGemm::LayoutScaleBias> tensor_Scale0;
506
+
507
+ if(alpha0 == ElementCompute(0)) //per-channel scale
508
+ tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
509
+
510
+ cutlass::HostTensor<
511
+ typename B2bGemm::ElementScaleBias,
512
+ typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
513
+
514
+ cutlass::HostTensor<
515
+ ElementAccumulator,
516
+ typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
517
+
518
+ cutlass::HostTensor<
519
+ typename B2bGemm::ElementC,
520
+ typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
521
+
522
+ cutlass::HostTensor<
523
+ typename B2bGemm::ElementB,
524
+ typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
525
+
526
+ cutlass::HostTensor<
527
+ typename B2bGemm::ElementC,
528
+ typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
529
+
530
+ cutlass::HostTensor<
531
+ typename B2bGemm::ElementC,
532
+ typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
533
+
534
+ cutlass::HostTensor<
535
+ typename B2bGemm::ElementC,
536
+ typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
537
+
538
+ cutlass::HostTensor<
539
+ typename B2bGemm::ElementC,
540
+ typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
541
+
542
+
543
+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
544
+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
545
+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
546
+ if(alpha0 == ElementCompute(0)) //per-channel scale
547
+ CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
548
+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
549
+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
550
+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
551
+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
552
+
553
+ cutlass::reference::host::TensorFill(
554
+ tensor_D1.host_view());
555
+ cutlass::reference::host::TensorFill(
556
+ reference_D0.host_view());
557
+ cutlass::reference::host::TensorFill(
558
+ reference_D1.host_view());
559
+
560
+ tensor_A0.sync_device();
561
+ tensor_B0.sync_device();
562
+ tensor_C0.sync_device();
563
+ if(alpha0 == ElementCompute(0)) //per-channel scale
564
+ tensor_Scale0.sync_device();
565
+ tensor_Bias0.sync_device();
566
+ tensor_B1.sync_device();
567
+ tensor_C1.sync_device();
568
+ tensor_Bias1.sync_device();
569
+ tensor_D1.sync_device();
570
+ reference_D0.sync_device();
571
+ reference_D1.sync_device();
572
+
573
+ //
574
+ // Initialize the GEMM operator
575
+ //
576
+
577
+ typename B2bGemm::Arguments arguments{
578
+ mode,
579
+ problem_size_0,
580
+ problem_size_1,
581
+ tensor_A0.device_ref(),
582
+ tensor_B0.device_ref(),
583
+ tensor_C0.device_ref(),
584
+ tensor_Scale0.device_ref(),
585
+ tensor_Bias0.device_ref(),
586
+ tensor_B1.device_ref(),
587
+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
588
+ tensor_D1.device_ref(),
589
+ batch_stride_A0,
590
+ batch_stride_B0,
591
+ batch_stride_B1,
592
+ batch_stride_C1,
593
+ batch_stride_D1,
594
+ batch_stride_Bias0,
595
+ batch_stride_Scale0,
596
+ {alpha0, beta0},
597
+ {alpha1, beta1},
598
+ batch_count,
599
+ };
600
+
601
+ B2bGemm b2b_gemm_op;
602
+
603
+ cutlass::Status status = b2b_gemm_op.can_implement(arguments);
604
+
605
+ if(status != cutlass::Status::kSuccess) {
606
+ std::cout << "Problem sizes not supported.\n"
607
+ << "Requirments:\n"
608
+ << " problem_size_0.M = problem_size_1.M\n"
609
+ << " problem_size_0.N = problem_size_1.K\n"
610
+ << " ThreadblockShape0::kN = problem_size_0.N\n"
611
+ << " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
612
+ }
613
+
614
+ status = b2b_gemm_op.initialize(arguments);
615
+
616
+ CUTLASS_CHECK(status);
617
+
618
+ for(int i = 0; i < warm_ups; i++) {
619
+ status = b2b_gemm_op();
620
+ CUTLASS_CHECK(status);
621
+ }
622
+
623
+ //
624
+ // Run the GEMM
625
+ //
626
+
627
+ cudaEvent_t start, stop;
628
+ cudaEventCreate(&start);
629
+ cudaEventCreate(&stop);
630
+
631
+ cudaEventRecord(start);
632
+
633
+ for(int i = 0; i < runs; i++) {
634
+ status = b2b_gemm_op();
635
+
636
+ CUTLASS_CHECK(status);
637
+ }
638
+
639
+ cudaEventRecord(stop);
640
+ cudaDeviceSynchronize();
641
+ float gemmTime;
642
+ cudaEventElapsedTime(&gemmTime, start, stop);
643
+ std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
644
+
645
+ tensor_D1.sync_host();
646
+
647
+ //
648
+ // Verify
649
+ //
650
+
651
+ cutlass::reference::device::GemmComplex<
652
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
653
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
654
+ ElementAccumulator, typename B2bGemm::LayoutC,
655
+ ElementAccumulator, ElementAccumulator
656
+ >(
657
+
658
+ problem_size_0,
659
+ ElementAccumulator(1), //intermediate alpha=1
660
+ tensor_A0.device_ref(),
661
+ cutlass::ComplexTransform::kNone,
662
+ tensor_B0.device_ref(),
663
+ cutlass::ComplexTransform::kNone,
664
+ ElementAccumulator(0), //beta = 0
665
+ reference_Z0.device_ref(),
666
+ reference_Z0.device_ref(),
667
+ ElementAccumulator(0),
668
+ int(batch_count),
669
+ batch_stride_A0,
670
+ batch_stride_B0,
671
+ batch_stride_C0,
672
+ batch_stride_C0
673
+ );
674
+
675
+ cutlass::reference::device::TensorScaleBiasGemmBatched<
676
+ ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
677
+ ElementCompute, typename B2bGemm::LayoutScaleBias
678
+ > (
679
+ problem_size_0,
680
+ reference_Z0.device_ref(),
681
+ reference_D0.device_ref(),
682
+ alpha0,
683
+ tensor_Scale0.device_ref(),
684
+ tensor_Bias0.device_ref(),
685
+ int(batch_count),
686
+ batch_stride_C0,
687
+ batch_stride_C0,
688
+ batch_stride_Scale0,
689
+ batch_stride_Bias0
690
+ );
691
+
692
+ if(relu) {
693
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
694
+ }
695
+
696
+ cutlass::reference::device::GemmComplex<
697
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
698
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
699
+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
700
+ ElementCompute, ElementAccumulator
701
+ >(
702
+ problem_size_1,
703
+ alpha1, //intermediate alpha=1
704
+ reference_D0.device_ref(),
705
+ cutlass::ComplexTransform::kNone,
706
+ tensor_B1.device_ref(),
707
+ cutlass::ComplexTransform::kNone,
708
+ beta1, //beta = 0
709
+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
710
+ reference_D1.device_ref(),
711
+ ElementAccumulator(0),
712
+ int(batch_count),
713
+ batch_stride_C0,
714
+ batch_stride_B1,
715
+ batch_stride_C1,
716
+ batch_stride_D1
717
+ );
718
+
719
+ if(relu) {
720
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
721
+ }
722
+
723
+ cudaDeviceSynchronize();
724
+ reference_D0.sync_host();
725
+ reference_D1.sync_host();
726
+
727
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
728
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
729
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
730
+
731
+ bool passed = cutlass::reference::host::TensorEquals(
732
+ reference_D1.host_view(),
733
+ tensor_D1.host_view());
734
+
735
+ CHECK_TRUE(passed);
736
+ if (!passed)
737
+ {
738
+
739
+ std::stringstream fname;
740
+
741
+ fname << "error_B2bGemm_device_fused.txt";
742
+ std::cerr << "Dumping results in " << fname.str() << "\n";
743
+
744
+ std::ofstream file(fname.str());
745
+
746
+ file
747
+ << "A0 =\n" << tensor_A0.host_view()
748
+ << "\nB0 =\n" << tensor_B0.host_view()
749
+ << "\nC0 =\n" << tensor_C0.host_view()
750
+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
751
+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
752
+ << "\nB1 =\n" << tensor_B1.host_view()
753
+ << "\nC1 =\n" << tensor_C1.host_view()
754
+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
755
+ << "\n\nReference =\n" << reference_D1.host_view()
756
+ << "\nComputed =\n" << tensor_D1.host_view();
757
+ }
758
+ return passed;
759
+ }
760
+
761
+ };
762
+
763
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Containers for running grouped back-to-back GEMMs
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include <iostream>
38
+ #include <fstream>
39
+ #include <sstream>
40
+
41
+ #include "cutlass/util/device_memory.h"
42
+ #include "cutlass/util/host_tensor.h"
43
+ #include "cutlass/util/tensor_view_io.h"
44
+ #include "cutlass/util/distribution.h"
45
+ #include "cutlass/util/reference/host/tensor_fill.h"
46
+ #include "cutlass/util/reference/host/tensor_copy.h"
47
+ #include "cutlass/util/reference/host/tensor_compare.h"
48
+ #include "cutlass/util/reference/host/tensor_norm.h"
49
+ #include "cutlass/util/reference/device/gemm.h"
50
+ #include "cutlass/util/reference/device/tensor_relu.h"
51
+
52
+ #include "reference/device/tensor_scale_bias.h"
53
+ #include "helper.h"
54
+
55
+ #define CHECK_GT(val1, val2) \
56
+ if((val1) <= (val2)) \
57
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
58
+ #define CHECK_TRUE(val) \
59
+ if(!(val)) \
60
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
61
+
62
+ ////////////////////////////////////////////////////////////////////////////////
63
+
64
+ template <typename B2bGemm_>
65
+ struct B2bFusedGroupedGemmRun
66
+ {
67
+
68
+ using B2bGemm = B2bGemm_;
69
+ using ElementAccumulator = typename B2bGemm::ElementAccumulator;
70
+ using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;
71
+
72
+ /// Initialization
73
+ cutlass::Distribution::Kind init_A;
74
+ cutlass::Distribution::Kind init_B;
75
+ cutlass::Distribution::Kind init_C;
76
+ cutlass::Distribution::Kind init_Scale;
77
+ cutlass::Distribution::Kind init_Bias;
78
+ uint64_t seed;
79
+
80
+ //
81
+ // Methods
82
+ //
83
+
84
+ B2bFusedGroupedGemmRun(
85
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
86
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
87
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
88
+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
89
+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
90
+ uint64_t seed_ = 2080
91
+ ):
92
+ init_A(init_A_), init_B(init_B_), init_C(init_C_),
93
+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
94
+
95
+ /// Helper to initialize a tensor view
96
+ template <typename Element, typename Layout>
97
+ bool initialize_tensor(
98
+ cutlass::TensorView<Element, Layout> view,
99
+ cutlass::Distribution::Kind dist_kind,
100
+ uint64_t seed) {
101
+
102
+ if (dist_kind == cutlass::Distribution::Uniform) {
103
+
104
+ cutlass::reference::host::TensorFillRandomUniform(
105
+ view, seed, 1, -1, 0);
106
+ }
107
+ else if (dist_kind == cutlass::Distribution::Identity) {
108
+
109
+ cutlass::reference::host::TensorFillIdentity(view);
110
+ }
111
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
112
+
113
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
114
+ }
115
+ else if (dist_kind == cutlass::Distribution::Sequential) {
116
+
117
+ cutlass::reference::host::BlockFillSequential(
118
+ view.data(), view.capacity());
119
+ }
120
+ else if (dist_kind == cutlass::Distribution::AllZeros) {
121
+ cutlass::reference::host::TensorFill(view, Element(0));
122
+ }
123
+ else if (dist_kind == cutlass::Distribution::AllOnes) {
124
+ cutlass::reference::host::TensorFill(view, Element(1));
125
+ }
126
+ else {
127
+ std::cerr << "Not implemented\n";
128
+ return false;
129
+ }
130
+
131
+ return true;
132
+ }
133
+
134
+ /// Executes one test
135
+ bool run(
136
+ std::vector<cutlass::gemm::GemmCoord> problem_sizes_0,
137
+ std::vector<cutlass::gemm::GemmCoord> problem_sizes_1,
138
+ ElementCompute alpha0 = ElementCompute(1),
139
+ ElementCompute beta0 = ElementCompute(0),
140
+ ElementCompute alpha1 = ElementCompute(1),
141
+ ElementCompute beta1 = ElementCompute(0),
142
+ bool relu = true,
143
+ int warm_ups = 1,
144
+ int runs = 100) {
145
+
146
+ using HostTensorA = cutlass::HostTensor<typename B2bGemm::ElementA, typename B2bGemm::LayoutA>;
147
+ using HostTensorB = cutlass::HostTensor<typename B2bGemm::ElementB, typename B2bGemm::LayoutB>;
148
+ using HostTensorC = cutlass::HostTensor<typename B2bGemm::ElementC, typename B2bGemm::LayoutC>;
149
+ using HostTensorScale = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
150
+ using HostTensorZ = cutlass::HostTensor<ElementAccumulator, typename B2bGemm::LayoutC>;
151
+ using HostTensorBias = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
152
+
153
+ int problem_count = (int)problem_sizes_0.size();
154
+
155
+ std::vector<HostTensorA> host_tensor_A0(problem_count);
156
+ std::vector<HostTensorB> host_tensor_B0(problem_count);
157
+ std::vector<HostTensorC> host_tensor_C0(problem_count);
158
+ std::vector<HostTensorScale> host_tensor_Scale0(problem_count);
159
+ std::vector<HostTensorScale> host_tensor_Bias0(problem_count);
160
+ std::vector<HostTensorB> host_tensor_B1(problem_count);
161
+ std::vector<HostTensorC> host_tensor_C1(problem_count);
162
+ std::vector<HostTensorBias> host_tensor_Bias1(problem_count);
163
+ std::vector<HostTensorC> host_tensor_D1(problem_count);
164
+ std::vector<HostTensorZ> host_tensor_Z(problem_count);
165
+ std::vector<HostTensorC> host_tensor_ref_D0(problem_count);
166
+ std::vector<HostTensorC> host_tensor_ref_D1(problem_count);
167
+
168
+ std::vector<typename HostTensorA::TensorRef> ref_A0(problem_count);
169
+ std::vector<typename HostTensorB::TensorRef> ref_B0(problem_count);
170
+ std::vector<typename HostTensorC::TensorRef> ref_C0(problem_count);
171
+ std::vector<typename HostTensorScale::TensorRef> ref_Scale0(problem_count);
172
+ std::vector<typename HostTensorScale::TensorRef> ref_Bias0(problem_count);
173
+ std::vector<typename HostTensorB::TensorRef> ref_B1(problem_count);
174
+ std::vector<typename HostTensorC::TensorRef> ref_C1(problem_count);
175
+ std::vector<typename HostTensorBias::TensorRef> ref_Bias1(problem_count);
176
+ std::vector<typename HostTensorC::TensorRef> ref_D1(problem_count);
177
+ std::vector<typename HostTensorZ::TensorRef> ref_Z(problem_count);
178
+ std::vector<typename HostTensorC::TensorRef> ref_ref_D0(problem_count);
179
+ std::vector<typename HostTensorC::TensorRef> ref_ref_D1(problem_count);
180
+
181
+ for (int i = 0; i < problem_count; ++i) {
182
+ //
183
+ // Allocate the GEMM workspace
184
+ //
185
+
186
+ auto problem_size_0 = problem_sizes_0[i];
187
+ auto problem_size_1 = problem_sizes_1[i];
188
+
189
+ host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk());
190
+ host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn());
191
+ host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn());
192
+ if (alpha0 == ElementCompute(0)) //per-channel scale
193
+ host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()});
194
+ host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()});
195
+ host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn());
196
+ host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn());
197
+ host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn());
198
+ host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn());
199
+ host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()});
200
+ host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn());
201
+ host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn());
202
+
203
+ CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019));
204
+ CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018));
205
+ CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017));
206
+ if (alpha0 == ElementCompute(0)) //per-channel scale
207
+ CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014));
208
+ CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013));
209
+ CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016));
210
+ CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015));
211
+ CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012));
212
+
213
+ cutlass::reference::host::TensorFill(
214
+ host_tensor_D1.at(i).host_view());
215
+ cutlass::reference::host::TensorFill(
216
+ host_tensor_ref_D0.at(i).host_view());
217
+ cutlass::reference::host::TensorFill(
218
+ host_tensor_ref_D1.at(i).host_view());
219
+
220
+ host_tensor_A0.at(i).sync_device();
221
+ host_tensor_B0.at(i).sync_device();
222
+ host_tensor_C0.at(i).sync_device();
223
+ if (alpha0 == ElementCompute(0)) //per-channel scale
224
+ host_tensor_Scale0.at(i).sync_device();
225
+ host_tensor_Bias0.at(i).sync_device();
226
+ host_tensor_B1.at(i).sync_device();
227
+ host_tensor_C1.at(i).sync_device();
228
+ host_tensor_Bias1.at(i).sync_device();
229
+ host_tensor_D1.at(i).sync_device();
230
+ host_tensor_ref_D0.at(i).sync_device();
231
+ host_tensor_ref_D1.at(i).sync_device();
232
+
233
+ ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
234
+ ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
235
+ ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
236
+ if (alpha0 == ElementCompute(0)) //per-channel scale
237
+ ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
238
+ ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref());
239
+ ref_B1.at(i) = (host_tensor_B1.at(i).device_ref());
240
+ ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)};
241
+ ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref());
242
+ ref_D1.at(i) = (host_tensor_D1.at(i).device_ref());
243
+ ref_Z.at(i) = (host_tensor_Z.at(i).device_ref());
244
+ ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref());
245
+ ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref());
246
+ }
247
+
248
+ //
249
+ // Initialize the GEMM operator
250
+ //
251
+
252
+ cutlass::DeviceAllocation<typename HostTensorA::TensorRef> device_ref_A0(problem_count);
253
+ device_ref_A0.copy_from_host(ref_A0.data());
254
+ cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B0(problem_count);
255
+ device_ref_B0.copy_from_host(ref_B0.data());
256
+ cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C0(problem_count);
257
+ device_ref_C0.copy_from_host(ref_C0.data());
258
+ cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Scale0(problem_count);
259
+ device_ref_Scale0.copy_from_host(ref_Scale0.data());
260
+ cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Bias0(problem_count);
261
+ device_ref_Bias0.copy_from_host(ref_Bias0.data());
262
+ cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B1(problem_count);
263
+ device_ref_B1.copy_from_host(ref_B1.data());
264
+ cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C1(problem_count);
265
+ device_ref_C1.copy_from_host(ref_C1.data());
266
+ cutlass::DeviceAllocation<typename HostTensorBias::TensorRef> device_ref_Bias1(problem_count);
267
+ device_ref_Bias1.copy_from_host(ref_Bias1.data());
268
+ cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_D1(problem_count);
269
+ device_ref_D1.copy_from_host(ref_D1.data());
270
+
271
+ cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_0(problem_count);
272
+ device_problem_sizes_0.copy_from_host(problem_sizes_0.data());
273
+ cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_1(problem_count);
274
+ device_problem_sizes_1.copy_from_host(problem_sizes_1.data());
275
+
276
+ B2bGemm b2b_gemm_op;
277
+
278
+ int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count);
279
+ if (!threadblock_count) {
280
+ std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl;
281
+ return false;
282
+ }
283
+
284
+ typename B2bGemm::Arguments arguments{
285
+ problem_count,
286
+ device_problem_sizes_0.get(),
287
+ device_problem_sizes_1.get(),
288
+ device_ref_A0.get(),
289
+ device_ref_B0.get(),
290
+ device_ref_C0.get(),
291
+ device_ref_Scale0.get(),
292
+ device_ref_Bias0.get(),
293
+ device_ref_B1.get(),
294
+ device_ref_C1.get(),
295
+ device_ref_D1.get(),
296
+ {alpha0, beta0},
297
+ {alpha1, beta1},
298
+ threadblock_count
299
+ };
300
+
301
+ cutlass::Status status = b2b_gemm_op.can_implement(arguments);
302
+
303
+ if(status != cutlass::Status::kSuccess) {
304
+ std::cout << "Problem sizes not supported.\n"
305
+ << "Requirments:\n"
306
+ << " problem_size_0.M = problem_size_1.M\n"
307
+ << " problem_size_0.N = problem_size_1.K\n"
308
+ << " ThreadblockShape0::kN = problem_size_0.N\n"
309
+ << " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
310
+ }
311
+
312
+ status = b2b_gemm_op.initialize(arguments);
313
+
314
+ CUTLASS_CHECK(status);
315
+
316
+ for(int i = 0; i < warm_ups; i++) {
317
+ status = b2b_gemm_op();
318
+ CUTLASS_CHECK(status);
319
+ }
320
+
321
+ //
322
+ // Run the GEMM
323
+ //
324
+
325
+ cudaEvent_t start, stop;
326
+ cudaEventCreate(&start);
327
+ cudaEventCreate(&stop);
328
+
329
+ cudaEventRecord(start);
330
+
331
+ for(int i = 0; i < runs; i++) {
332
+ status = b2b_gemm_op();
333
+ CUTLASS_CHECK(status);
334
+ }
335
+
336
+ cudaEventRecord(stop);
337
+ cudaDeviceSynchronize();
338
+ float gemmTime;
339
+ cudaEventElapsedTime(&gemmTime, start, stop);
340
+ std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
341
+
342
+ for (int i = 0; i < problem_count; ++i) {
343
+ host_tensor_D1.at(i).sync_host();
344
+
345
+ //
346
+ // Verify
347
+ //
348
+
349
+ cutlass::reference::device::Gemm<
350
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
351
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
352
+ ElementAccumulator, typename B2bGemm::LayoutC,
353
+ ElementAccumulator, ElementAccumulator>
354
+ reference_gemm_0;
355
+
356
+ cutlass::reference::device::Gemm<
357
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
358
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
359
+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
360
+ ElementAccumulator>
361
+ reference_gemm_1;
362
+
363
+ auto problem_size_0 = problem_sizes_0[i];
364
+ auto problem_size_1 = problem_sizes_1[i];
365
+
366
+ reference_gemm_0(
367
+ problem_size_0,
368
+ ElementAccumulator(1), //intermediate alpha=1
369
+ ref_A0.at(i),
370
+ ref_B0.at(i),
371
+ ElementAccumulator(0), //beta = 0
372
+ ref_Z.at(i),
373
+ ref_Z.at(i),
374
+ ElementAccumulator(0)
375
+ );
376
+
377
+ cutlass::reference::device::TensorScaleBiasGemm<
378
+ ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
379
+ ElementCompute, typename B2bGemm::LayoutC
380
+ > (
381
+ problem_size_0,
382
+ ref_Z.at(i),
383
+ ref_ref_D0.at(i),
384
+ alpha0,
385
+ ref_Scale0.at(i),
386
+ ref_Bias0.at(i)
387
+ );
388
+
389
+ if(relu) {
390
+ cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
391
+ }
392
+
393
+ reference_gemm_1(
394
+ problem_size_1,
395
+ alpha1,
396
+ ref_ref_D0.at(i),
397
+ ref_B1.at(i),
398
+ beta1,
399
+ {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
400
+ ref_ref_D1.at(i)
401
+ );
402
+ if(relu) {
403
+ cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
404
+ }
405
+ cudaDeviceSynchronize();
406
+ host_tensor_ref_D0.at(i).sync_host();
407
+ host_tensor_ref_D1.at(i).sync_host();
408
+
409
+ CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
410
+ CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
411
+ CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);
412
+
413
+ bool passed = cutlass::reference::host::TensorEquals(
414
+ host_tensor_ref_D1.at(i).host_view(),
415
+ host_tensor_D1.at(i).host_view());
416
+
417
+ CHECK_TRUE(passed);
418
+ if (!passed)
419
+ {
420
+
421
+ std::stringstream fname;
422
+
423
+ fname << "error_B2bGemm_device_fused.txt";
424
+ std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl;
425
+ std::cerr << "Dumping results in " << fname.str() << "\n";
426
+
427
+ std::ofstream file(fname.str());
428
+
429
+ file
430
+ << "GEMM " << i << " in group\n"
431
+ << "A0 =\n" << host_tensor_A0.at(i).host_view()
432
+ << "\nB0 =\n" << host_tensor_B0.at(i).host_view()
433
+ << "\nC0 =\n" << host_tensor_C0.at(i).host_view()
434
+ << "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n"
435
+ << "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n"
436
+ << "\nB1 =\n" << host_tensor_B1.at(i).host_view()
437
+ << "\nC1 =\n" << host_tensor_C1.at(i).host_view()
438
+ << "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n"
439
+ << "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view()
440
+ << "\nComputed =\n" << host_tensor_D1.at(i).host_view();
441
+
442
+ return false;
443
+ }
444
+ }
445
+ return true;
446
+ }
447
+
448
+ };
449
+
450
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include <iostream>
35
+ #include <fstream>
36
+ #include <sstream>
37
+
38
+ #include "cutlass/cutlass.h"
39
+
40
+ #include "cutlass/conv/device/implicit_gemm_convolution.h"
41
+ #include "cutlass/reduction/device/reduce_split_k.h"
42
+ #include "cutlass/reduction/thread/reduction_operators.h"
43
+
44
+ #include "cutlass/util/host_tensor.h"
45
+ #include "cutlass/util/reference/host/tensor_fill.h"
46
+ #include "cutlass/util/reference/device/tensor_compare.h"
47
+ #include "cutlass/util/reference/host/tensor_compare.h"
48
+ #include "cutlass/util/reference/host/tensor_norm.h"
49
+ #include "cutlass/util/host_reorder.h"
50
+
51
+ #include "cutlass/util/reference/host/convolution.h"
52
+ #include "cutlass/util/reference/device/convolution.h"
53
+ #include "cutlass/util/reference/device/tensor_relu.h"
54
+
55
+ #include "cutlass/core_io.h"
56
+ #include "cutlass/util/tensor_view_io.h"
57
+
58
+ #include "reference/device/tensor_scale_bias.h"
59
+ #include "helper.h"
60
+
61
+ #define CHECK_GT(val1, val2) \
62
+ if((val1) <= (val2)) \
63
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
64
+ #define CHECK_TRUE(val) \
65
+ if(!(val)) \
66
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
67
+
68
+
69
+ template <typename Conv2d0_, typename Conv2d1_, int InterleavedK>
70
+ class B2bInterleavedNonFusedConv2dRun {
71
+ public:
72
+
73
+ using Conv2d0 = Conv2d0_;
74
+ using Conv2d1 = Conv2d1_;
75
+ using ElementAccumulator = typename Conv2d0::ElementAccumulator;
76
+ using ElementCompute = typename Conv2d0::ElementCompute;
77
+
78
+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator;
79
+ static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator,
80
+ "Fused convolution operators must be the same");
81
+
82
+ public:
83
+
84
+ /// Initialization
85
+ cutlass::Distribution::Kind init_A;
86
+ cutlass::Distribution::Kind init_B;
87
+ cutlass::Distribution::Kind init_C;
88
+ cutlass::Distribution::Kind init_Bias;
89
+ uint64_t seed;
90
+
91
+ cutlass::HostTensor<typename Conv2d0::ElementA, typename Conv2d0::LayoutA> tensor_A0;
92
+ cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
93
+ cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0_reordered;
94
+ cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
95
+ cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_Bias0;
96
+ cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
97
+ cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;
98
+
99
+ cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
100
+ cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1_reordered;
101
+ cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
102
+ cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d0::LayoutC> tensor_Bias1;
103
+ cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
104
+ cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;
105
+
106
+
107
+ public:
108
+
109
+ B2bInterleavedNonFusedConv2dRun(
110
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
111
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
112
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
113
+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
114
+ uint64_t seed_ = 2080
115
+ ):
116
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) {
117
+
118
+ }
119
+
120
+ /// Helper to initialize a tensor view
121
+ template <typename Element, typename Layout>
122
+ void initialize_tensor(
123
+ cutlass::TensorView<Element, Layout> view,
124
+ cutlass::Distribution::Kind dist_kind,
125
+ uint64_t seed) {
126
+
127
+ if (dist_kind == cutlass::Distribution::Uniform) {
128
+
129
+ int scope;
130
+ int bits = cutlass::sizeof_bits<Element>::value;
131
+
132
+ if (bits <= 16) {
133
+ scope = 2;
134
+ }
135
+ else {
136
+ scope = 8;
137
+ }
138
+ cutlass::reference::host::TensorFillRandomUniform(
139
+ view, seed, scope, -scope, 0);
140
+ }
141
+ else if (dist_kind == cutlass::Distribution::Identity) {
142
+
143
+ cutlass::reference::host::TensorFillIdentity(view);
144
+ }
145
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
146
+
147
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
148
+ }
149
+ else if (dist_kind == cutlass::Distribution::Sequential) {
150
+
151
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
152
+ }
153
+ else if (dist_kind == cutlass::Distribution::AllZeros) {
154
+ cutlass::reference::host::TensorFill(view, Element(0));
155
+ }
156
+ else if (dist_kind == cutlass::Distribution::AllOnes) {
157
+ cutlass::reference::host::TensorFill(view, Element(1));
158
+ }
159
+ else {
160
+ }
161
+ }
162
+
163
+ void initialize(
164
+ cutlass::conv::Conv2dProblemSize const &problem_size_0,
165
+ cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) {
166
+
167
+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
168
+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
169
+ tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
170
+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
171
+ tensor_Bias0.resize({1, 1, 1, problem_size_0.K});
172
+ tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
173
+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
174
+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
175
+ tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
176
+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
177
+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
178
+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
179
+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
180
+
181
+ initialize_tensor(tensor_A0.host_view(), init_A, seed);
182
+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
183
+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
184
+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
185
+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
186
+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
187
+
188
+ //Reorder B0 and B1
189
+ cutlass::reorder_convK<InterleavedK, InterleavedK>(
190
+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0));
191
+ cutlass::reorder_convK<InterleavedK, InterleavedK>(
192
+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1));
193
+
194
+ tensor_A0.sync_device();
195
+ tensor_B0.sync_device();
196
+ tensor_B0_reordered.sync_device();
197
+ tensor_C0.sync_device();
198
+ tensor_Bias0.sync_device();
199
+ tensor_D0_computed.sync_device();
200
+ tensor_D0_reference.sync_device();
201
+ tensor_B1.sync_device();
202
+ tensor_B1_reordered.sync_device();
203
+ tensor_C1.sync_device();
204
+ tensor_Bias1.sync_device();
205
+ tensor_D1_computed.sync_device();
206
+ tensor_D1_reference.sync_device();
207
+ }
208
+
209
+ /// Executes one test
210
+ bool run(
211
+ cutlass::conv::Conv2dProblemSize const &problem_size_0,
212
+ cutlass::conv::Conv2dProblemSize const &problem_size_1,
213
+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
214
+ ElementCompute alpha0 = ElementCompute(1),
215
+ ElementCompute beta0 = ElementCompute(0),
216
+ ElementCompute alpha1 = ElementCompute(1),
217
+ ElementCompute beta1 = ElementCompute(0),
218
+ bool relu = true,
219
+ int warm_ups = 1,
220
+ int runs = 100) {
221
+
222
+ initialize(problem_size_0, problem_size_1);
223
+
224
+ // configure the operator
225
+ Conv2d0 conv2d_op_0;
226
+ Conv2d1 conv2d_op_1;
227
+
228
+ typename Conv2d0::Arguments conv2d_args_0(
229
+ problem_size_0,
230
+ tensor_A0.device_ref(),
231
+ tensor_B0_reordered.device_ref(),
232
+ tensor_C0.device_ref(),
233
+ tensor_D0_computed.device_ref(),
234
+ {alpha0, beta0},
235
+ split_k_mode
236
+ );
237
+ typename Conv2d1::Arguments conv2d_args_1(
238
+ problem_size_1,
239
+ tensor_D0_computed.device_ref(),
240
+ tensor_B1_reordered.device_ref(),
241
+ tensor_C1.device_ref(),
242
+ tensor_D1_computed.device_ref(),
243
+ {alpha1, beta1},
244
+ split_k_mode
245
+ );
246
+
247
+
248
+ cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0);
249
+
250
+ CUTLASS_CHECK(status);
251
+
252
+ status = conv2d_op_1.initialize(conv2d_args_1);
253
+
254
+ CUTLASS_CHECK(status);
255
+
256
+ for(int i = 0; i < warm_ups; i++) {
257
+ status = conv2d_op_0();
258
+ CUTLASS_CHECK(status);
259
+ status = conv2d_op_1();
260
+ CUTLASS_CHECK(status);
261
+ }
262
+
263
+ //
264
+ // Run Conv2d
265
+ //
266
+ cudaEvent_t start, stop1, stop2;
267
+ cudaEventCreate(&start);
268
+ cudaEventCreate(&stop1);
269
+ cudaEventCreate(&stop2);
270
+
271
+ cudaEventRecord(start);
272
+
273
+
274
+ for(int i = 0; i < runs; i++) {
275
+ // run conv2d operator
276
+ status = conv2d_op_0();
277
+ CUTLASS_CHECK(status);
278
+ }
279
+ cudaEventRecord(stop1);
280
+
281
+ for(int i = 0; i < runs; i++) {
282
+ // run conv2d operator
283
+ status = conv2d_op_1();
284
+ CUTLASS_CHECK(status);
285
+ }
286
+ cudaEventRecord(stop2);
287
+ cudaDeviceSynchronize();
288
+ float conv2d0Time, conv2d1Time, totalTime;
289
+ cudaEventElapsedTime(&conv2d0Time, start, stop1);
290
+ cudaEventElapsedTime(&conv2d1Time, stop1, stop2);
291
+ cudaEventElapsedTime(&totalTime, start, stop2);
292
+ std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n";
293
+ std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n";
294
+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
295
+
296
+ tensor_D0_computed.sync_host();
297
+ tensor_D1_computed.sync_host();
298
+
299
+ bool passed = false;
300
+
301
+ cutlass::reference::device::Conv2d<
302
+ typename Conv2d0::ElementA,
303
+ typename Conv2d0::LayoutA,
304
+ typename Conv2d0::ElementB,
305
+ typename Conv2d0::LayoutB,
306
+ typename Conv2d0::ElementC,
307
+ typename Conv2d0::LayoutC,
308
+ ElementCompute,
309
+ ElementAccumulator,
310
+ cutlass::NumericConverterClamp<typename Conv2d0::ElementC, ElementCompute>
311
+ >(
312
+ kConvolutionalOperator,
313
+ problem_size_0,
314
+ tensor_A0.device_ref(),
315
+ tensor_B0.device_ref(),
316
+ tensor_C0.device_ref(),
317
+ tensor_D0_reference.device_ref(),
318
+ alpha0,
319
+ beta0);
320
+
321
+ if(relu) {
322
+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
323
+ }
324
+
325
+ cutlass::reference::device::Conv2d<
326
+ typename Conv2d1::ElementA,
327
+ typename Conv2d1::LayoutA,
328
+ typename Conv2d1::ElementB,
329
+ typename Conv2d1::LayoutB,
330
+ typename Conv2d1::ElementC,
331
+ typename Conv2d1::LayoutC,
332
+ ElementCompute,
333
+ ElementAccumulator,
334
+ cutlass::NumericConverterClamp<typename Conv2d1::ElementC, ElementCompute>
335
+ >(
336
+ kConvolutionalOperator,
337
+ problem_size_1,
338
+ tensor_D0_reference.device_ref(),
339
+ tensor_B1.device_ref(),
340
+ tensor_C1.device_ref(),
341
+ tensor_D1_reference.device_ref(),
342
+ alpha1,
343
+ beta1);
344
+
345
+ if(relu) {
346
+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
347
+ }
348
+
349
+ cudaError_t result = cudaDeviceSynchronize();
350
+ CHECK_TRUE(result == cudaSuccess);
351
+
352
+ // sync host (copy device data to host) for dumping error output in case of mismatches
353
+ tensor_D0_reference.sync_host();
354
+ tensor_D1_reference.sync_host();
355
+
356
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0);
357
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
358
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
359
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
360
+
361
+ passed = cutlass::reference::host::TensorEquals(
362
+ tensor_D1_computed.host_view(),
363
+ tensor_D1_reference.host_view());
364
+
365
+ CHECK_TRUE(passed);
366
+
367
+ if (!passed) {
368
+ std::stringstream fname;
369
+
370
+ fname << "error_B2bImplicitGemm_device_interleaved_nonfused.txt";
371
+ std::cerr << "Dumping results in " << fname.str() << "\n";
372
+
373
+ std::ofstream results(fname.str());
374
+
375
+ results << problem_size_0 << std::endl;
376
+ results << problem_size_1 << std::endl;
377
+
378
+ results
379
+ << "\nA0:\n" << tensor_A0.host_view() << "\n"
380
+ << "\nB0:\n" << tensor_B0.host_view() << "\n"
381
+ << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
382
+ << "\nC0:\n" << tensor_C0.host_view() << "\n"
383
+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
384
+ << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
385
+ << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
386
+ << "\nB1:\n" << tensor_B1.host_view() << "\n"
387
+ << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
388
+ << "\nC1:\n" << tensor_C1.host_view() << "\n"
389
+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
390
+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
391
+ << "\nD1 computed:\n" << tensor_D1_computed.host_view();
392
+
393
+
394
+ }
395
+
396
+ return passed;
397
+ }
398
+
399
+ };
400
+
401
+ template <typename B2bConv2d_, int InterleavedK>
402
+ class B2bInterleavedFusedConv2dRun {
403
+ public:
404
+
405
+ using B2bConv2d = B2bConv2d_;
406
+ using ElementAccumulator = typename B2bConv2d::ElementAccumulator;
407
+ using ElementCompute = typename B2bConv2d::ElementCompute;
408
+
409
+ static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator;
410
+
411
+ public:
412
+
413
+ /// Initialization
414
+ cutlass::Distribution::Kind init_A;
415
+ cutlass::Distribution::Kind init_B;
416
+ cutlass::Distribution::Kind init_C;
417
+ cutlass::Distribution::Kind init_Scale;
418
+ cutlass::Distribution::Kind init_Bias;
419
+ uint64_t seed;
420
+
421
+ cutlass::HostTensor<typename B2bConv2d::ElementA, typename B2bConv2d::LayoutA> tensor_A0;
422
+ cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0;
423
+ cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B0_reordered;
424
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
425
+ cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
426
+ cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
427
+ cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
428
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
429
+
430
+ cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
431
+ cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1_reordered;
432
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
433
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_Bias1;
434
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
435
+ cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;
436
+
437
+
438
+ public:
439
+
440
+ B2bInterleavedFusedConv2dRun(
441
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
442
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
443
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
444
+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
445
+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
446
+ uint64_t seed_ = 2080
447
+ ):
448
+ init_A(init_A_), init_B(init_B_), init_C(init_C_),
449
+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) {
450
+
451
+ }
452
+
453
+ /// Helper to initialize a tensor view
454
+ template <typename Element, typename Layout>
455
+ void initialize_tensor(
456
+ cutlass::TensorView<Element, Layout> view,
457
+ cutlass::Distribution::Kind dist_kind,
458
+ uint64_t seed) {
459
+
460
+ if (dist_kind == cutlass::Distribution::Uniform) {
461
+
462
+ int scope;
463
+ int bits = cutlass::sizeof_bits<Element>::value;
464
+
465
+ if (bits <= 16) {
466
+ scope = 2;
467
+ }
468
+ else {
469
+ scope = 8;
470
+ }
471
+ cutlass::reference::host::TensorFillRandomUniform(
472
+ view, seed, scope, -scope, 0);
473
+ }
474
+ else if (dist_kind == cutlass::Distribution::Identity) {
475
+
476
+ cutlass::reference::host::TensorFillIdentity(view);
477
+ }
478
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
479
+
480
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
481
+ }
482
+ else if (dist_kind == cutlass::Distribution::Sequential) {
483
+
484
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
485
+ }
486
+ else if (dist_kind == cutlass::Distribution::AllZeros) {
487
+ cutlass::reference::host::TensorFill(view, Element(0));
488
+ }
489
+ else if (dist_kind == cutlass::Distribution::AllOnes) {
490
+ cutlass::reference::host::TensorFill(view, Element(1));
491
+ }
492
+ else {
493
+ }
494
+ }
495
+
496
+ void initialize(
497
+ cutlass::conv::Conv2dProblemSize const &problem_size_0,
498
+ cutlass::conv::Conv2dProblemSize const &problem_size_1,
499
+ ElementCompute alpha0,
500
+ ElementCompute alpha1,
501
+ uint64_t seed = 2019) {
502
+
503
+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0));
504
+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
505
+ tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0));
506
+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
507
+ if(alpha0 == ElementCompute(0)) //per-channel scale
508
+ tensor_Scale0.resize({1, problem_size_0.K});
509
+ tensor_Bias0.resize({1, problem_size_0.K});
510
+ tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
511
+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
512
+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
513
+ tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
514
+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
515
+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K});
516
+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
517
+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
518
+
519
+ initialize_tensor(tensor_A0.host_view(), init_A, seed);
520
+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17);
521
+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39);
522
+ if(alpha0 == ElementCompute(0)) //per-channel scale
523
+ initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61);
524
+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83);
525
+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18);
526
+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40);
527
+ initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84);
528
+
529
+ //Reorder B0 and B1
530
+ cutlass::reorder_convK<16, InterleavedK>(
531
+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0));
532
+ cutlass::reorder_convK<InterleavedK, InterleavedK>(
533
+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1));
534
+
535
+ tensor_A0.sync_device();
536
+ tensor_B0.sync_device();
537
+ tensor_B0_reordered.sync_device();
538
+ tensor_C0.sync_device();
539
+ if(alpha0 == ElementCompute(0)) //per-channel scale
540
+ tensor_Scale0.sync_device();
541
+ tensor_Bias0.sync_device();
542
+ tensor_D0_reference.sync_device();
543
+ tensor_B1.sync_device();
544
+ tensor_B1_reordered.sync_device();
545
+ tensor_C1.sync_device();
546
+ tensor_Bias1.sync_device();
547
+ tensor_D1_computed.sync_device();
548
+ tensor_D1_reference.sync_device();
549
+ }
550
+
551
+ /// Executes one test
552
+ bool run(
553
+ cutlass::conv::Conv2dProblemSize const &problem_size_0,
554
+ cutlass::conv::Conv2dProblemSize const &problem_size_1,
555
+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
556
+ ElementCompute alpha0 = ElementCompute(1),
557
+ ElementCompute beta0 = ElementCompute(0),
558
+ ElementCompute alpha1 = ElementCompute(1),
559
+ ElementCompute beta1 = ElementCompute(0),
560
+ bool relu = true,
561
+ int warm_ups = 1,
562
+ int runs = 100) {
563
+
564
+ initialize(problem_size_0, problem_size_1, alpha0, alpha1);
565
+
566
+ // configure the operator
567
+ B2bConv2d b2b_conv2d_op;
568
+
569
+ typename B2bConv2d::Arguments b2b_conv2d_args(
570
+ problem_size_0,
571
+ problem_size_1,
572
+ tensor_A0.device_ref(),
573
+ tensor_B0_reordered.device_ref(),
574
+ tensor_C0.device_ref(),
575
+ tensor_Scale0.device_ref(),
576
+ tensor_Bias0.device_ref(),
577
+ tensor_B1_reordered.device_ref(),
578
+ tensor_C1.device_ref(),
579
+ tensor_D1_computed.device_ref(),
580
+ {alpha0, beta0},
581
+ {alpha1, beta1},
582
+ split_k_mode
583
+ );
584
+
585
+ cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args);
586
+
587
+ if(status != cutlass::Status::kSuccess) {
588
+ std::cout << "Problem sizes not supported.\n"
589
+ << "Requirments:\n"
590
+ << " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n"
591
+ << " problem_size_0.K = problem_size_1.C\n"
592
+ << " problem_size_1.R = problem_size_1.S = 1\n"
593
+ << " ThreadblockShape0::kN = problem_size_0.K\n"
594
+ << " ThreadblockShape1::kN = problem_size_1.K" << std::endl;
595
+ }
596
+
597
+ CUTLASS_CHECK(status);
598
+
599
+ status = b2b_conv2d_op.initialize(b2b_conv2d_args);
600
+
601
+ CUTLASS_CHECK(status);
602
+
603
+ for(int i = 0; i < warm_ups; i++) {
604
+ status = b2b_conv2d_op();
605
+ CUTLASS_CHECK(status);
606
+ }
607
+
608
+ //
609
+ // Run the Conv2d
610
+ //
611
+
612
+ cudaEvent_t start, stop;
613
+ cudaEventCreate(&start);
614
+ cudaEventCreate(&stop);
615
+
616
+ cudaEventRecord(start);
617
+
618
+ for(int i = 0; i < runs; i++) {
619
+
620
+ // run conv2d operator
621
+ status = b2b_conv2d_op();
622
+ CUTLASS_CHECK(status);
623
+ }
624
+
625
+ cudaEventRecord(stop);
626
+ cudaDeviceSynchronize();
627
+ float conv2dTime;
628
+ cudaEventElapsedTime(&conv2dTime, start, stop);
629
+ std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n";
630
+
631
+ tensor_D1_computed.sync_host();
632
+
633
+ bool passed = false;
634
+
635
+ cutlass::reference::device::Conv2d<
636
+ typename B2bConv2d::ElementA,
637
+ typename B2bConv2d::LayoutA,
638
+ typename B2bConv2d::ElementB,
639
+ typename B2bConv2d::LayoutB,
640
+ ElementAccumulator,
641
+ typename B2bConv2d::LayoutC,
642
+ ElementAccumulator,
643
+ ElementAccumulator
644
+ >(
645
+ kConvolutionalOperator,
646
+ problem_size_0,
647
+ tensor_A0.device_ref(),
648
+ tensor_B0.device_ref(),
649
+ tensor_Z0_reference.device_ref(),
650
+ tensor_Z0_reference.device_ref(),
651
+ ElementAccumulator(1), // intermediate alpha = 1
652
+ ElementAccumulator(0) // beta = 0
653
+ );
654
+
655
+ cutlass::reference::device::TensorScaleBiasConv2d<
656
+ ElementAccumulator,
657
+ typename B2bConv2d::ElementC,
658
+ typename B2bConv2d::LayoutC,
659
+ ElementCompute,
660
+ typename B2bConv2d::LayoutScaleBias,
661
+ cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
662
+ >(
663
+ problem_size_0,
664
+ tensor_Z0_reference.device_ref(),
665
+ tensor_D0_reference.device_ref(),
666
+ alpha0,
667
+ tensor_Scale0.device_ref(),
668
+ tensor_Bias0.device_ref()
669
+ );
670
+
671
+ if(relu) {
672
+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
673
+ }
674
+
675
+ cutlass::reference::device::Conv2d<
676
+ typename B2bConv2d::ElementA,
677
+ typename B2bConv2d::LayoutA,
678
+ typename B2bConv2d::ElementB,
679
+ typename B2bConv2d::LayoutB,
680
+ typename B2bConv2d::ElementC,
681
+ typename B2bConv2d::LayoutC,
682
+ ElementCompute,
683
+ ElementAccumulator,
684
+ cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
685
+ >(
686
+ kConvolutionalOperator,
687
+ problem_size_1,
688
+ tensor_D0_reference.device_ref(),
689
+ tensor_B1.device_ref(),
690
+ tensor_C1.device_ref(),
691
+ tensor_D1_reference.device_ref(),
692
+ alpha1,
693
+ beta1);
694
+
695
+ if(relu) {
696
+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view());
697
+ }
698
+
699
+ cudaError_t result = cudaDeviceSynchronize();
700
+ CHECK_TRUE(result == cudaSuccess);
701
+
702
+ // sync host (copy device data to host) for dumping error output in case of mismatches
703
+ tensor_D0_reference.sync_host();
704
+ tensor_D1_reference.sync_host();
705
+
706
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0);
707
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0);
708
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0);
709
+
710
+ passed = cutlass::reference::host::TensorEquals(
711
+ tensor_D1_computed.host_view(),
712
+ tensor_D1_reference.host_view());
713
+
714
+ CHECK_TRUE(passed);
715
+
716
+ if (!passed) {
717
+ std::stringstream fname;
718
+
719
+ fname << "error_B2bImplicitGemm_device_interleaved_fused.txt";
720
+ std::cerr << "Dumping results in " << fname.str() << "\n";
721
+
722
+ std::ofstream results(fname.str());
723
+
724
+ results << problem_size_0 << std::endl;
725
+ results << problem_size_1 << std::endl;
726
+
727
+ results
728
+ << "\nA0:\n" << tensor_A0.host_view() << "\n"
729
+ << "\nB0:\n" << tensor_B0.host_view() << "\n"
730
+ << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
731
+ << "\nC0:\n" << tensor_C0.host_view() << "\n"
732
+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
733
+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
734
+ << "\nB1:\n" << tensor_B1.host_view() << "\n"
735
+ << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
736
+ << "\nC1:\n" << tensor_C1.host_view() << "\n"
737
+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
738
+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
739
+ << "\nD1 computed:\n" << tensor_D1_computed.host_view();
740
+
741
+
742
+ }
743
+
744
+ return passed;
745
+ }
746
+
747
+ };
748
+
749
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h ADDED
@@ -0,0 +1,798 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include <iostream>
34
+ #include <fstream>
35
+ #include <sstream>
36
+
37
+ #include "cutlass/util/host_tensor.h"
38
+ #include "cutlass/util/tensor_view_io.h"
39
+ #include "cutlass/util/distribution.h"
40
+ #include "cutlass/util/reference/host/tensor_fill.h"
41
+ #include "cutlass/util/reference/host/tensor_copy.h"
42
+ #include "cutlass/util/reference/host/tensor_compare.h"
43
+ #include "cutlass/util/reference/host/tensor_norm.h"
44
+ #include "cutlass/util/host_reorder.h"
45
+ #include "cutlass/util/reference/device/gemm.h"
46
+ #include "cutlass/util/reference/device/gemm_complex.h"
47
+ #include "cutlass/util/reference/device/tensor_relu.h"
48
+
49
+ #include "reference/device/tensor_scale_bias.h"
50
+ #include "helper.h"
51
+
52
+ #define CHECK_GT(val1, val2) \
53
+ if((val1) <= (val2)) \
54
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
55
+ #define CHECK_TRUE(val) \
56
+ if(!(val)) \
57
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
58
+
59
+ template <typename Gemm0_, typename Gemm1_, int InterleavedK_>
60
+ struct B2bInterleavedNonFusedGemmRun
61
+ {
62
+
63
+ using Gemm0 = Gemm0_;
64
+ using Gemm1 = Gemm1_;
65
+ using ElementAccumulator = typename Gemm0::ElementAccumulator;
66
+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
67
+
68
+ /// Initialization
69
+ cutlass::Distribution::Kind init_A;
70
+ cutlass::Distribution::Kind init_B;
71
+ cutlass::Distribution::Kind init_C;
72
+ cutlass::Distribution::Kind init_Bias;
73
+ uint64_t seed;
74
+
75
+ //
76
+ // Methods
77
+ //
78
+
79
+ B2bInterleavedNonFusedGemmRun(
80
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
81
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
82
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
83
+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
84
+ uint64_t seed_ = 2080
85
+ ):
86
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }
87
+
88
+ /// Helper to initialize a tensor view
89
+ template <typename Element, typename Layout>
90
+ bool initialize_tensor(
91
+ cutlass::TensorView<Element, Layout> view,
92
+ cutlass::Distribution::Kind dist_kind,
93
+ uint64_t seed) {
94
+
95
+ if (dist_kind == cutlass::Distribution::Uniform) {
96
+
97
+ cutlass::reference::host::TensorFillRandomUniform(
98
+ view, seed, 2, -2, 0);
99
+ }
100
+ else if (dist_kind == cutlass::Distribution::Identity) {
101
+
102
+ cutlass::reference::host::TensorFillIdentity(view);
103
+ }
104
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
105
+
106
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
107
+ }
108
+ else if (dist_kind == cutlass::Distribution::Sequential) {
109
+
110
+ cutlass::reference::host::BlockFillSequential(
111
+ view.data(), view.capacity());
112
+ }
113
+ else if (dist_kind == cutlass::Distribution::AllZeros) {
114
+ cutlass::reference::host::TensorFill(view, Element(0));
115
+ }
116
+ else if (dist_kind == cutlass::Distribution::AllOnes) {
117
+ cutlass::reference::host::TensorFill(view, Element(1));
118
+ }
119
+ else {
120
+ std::cerr << "Not implemented\n";
121
+ return false;
122
+ }
123
+
124
+ return true;
125
+ }
126
+
127
+
128
+
129
+
130
+ /// Executes one test
131
+ bool run(
132
+ cutlass::gemm::GemmCoord problem_size_0,
133
+ cutlass::gemm::GemmCoord problem_size_1,
134
+ ElementCompute alpha0 = ElementCompute(1),
135
+ ElementCompute beta0 = ElementCompute(0),
136
+ ElementCompute alpha1 = ElementCompute(1),
137
+ ElementCompute beta1 = ElementCompute(0),
138
+ bool relu = true,
139
+ int warm_ups = 1,
140
+ int runs = 100) {
141
+
142
+ //
143
+ // Allocate the GEMM workspace
144
+ //
145
+
146
+ cutlass::HostTensor<
147
+ typename Gemm0::ElementA,
148
+ typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
149
+
150
+ cutlass::HostTensor<
151
+ typename Gemm0::ElementB,
152
+ typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
153
+
154
+ cutlass::HostTensor<
155
+ typename Gemm0::ElementB,
156
+ typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
157
+
158
+ cutlass::HostTensor<
159
+ typename Gemm0::ElementC,
160
+ typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
161
+
162
+ cutlass::HostTensor<
163
+ typename Gemm0::ElementC,
164
+ typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
165
+
166
+ cutlass::HostTensor<
167
+ typename Gemm0::ElementC,
168
+ typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
169
+
170
+ cutlass::HostTensor<
171
+ typename Gemm0::ElementC,
172
+ typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
173
+
174
+ cutlass::HostTensor<
175
+ typename Gemm1::ElementB,
176
+ typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
177
+
178
+ cutlass::HostTensor<
179
+ typename Gemm1::ElementB,
180
+ typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
181
+
182
+ cutlass::HostTensor<
183
+ typename Gemm1::ElementC,
184
+ typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
185
+
186
+ cutlass::HostTensor<
187
+ typename Gemm0::ElementC,
188
+ typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
189
+
190
+ cutlass::HostTensor<
191
+ typename Gemm1::ElementC,
192
+ typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
193
+
194
+ cutlass::HostTensor<
195
+ typename Gemm1::ElementC,
196
+ typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
197
+
198
+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
199
+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
200
+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
201
+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
202
+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
203
+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
204
+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));
205
+
206
+ //Reorder B0 and B1
207
+ cutlass::reorder_column<InterleavedK_>(
208
+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
209
+ cutlass::reorder_column<InterleavedK_>(
210
+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
211
+
212
+ cutlass::reference::host::TensorFill(
213
+ tensor_D0.host_view());
214
+ cutlass::reference::host::TensorFill(
215
+ tensor_D1.host_view());
216
+ cutlass::reference::host::TensorFill(
217
+ reference_D0.host_view());
218
+ cutlass::reference::host::TensorFill(
219
+ reference_D1.host_view());
220
+
221
+ tensor_A0.sync_device();
222
+ tensor_B0.sync_device();
223
+ tensor_B0_reordered.sync_device();
224
+ tensor_C0.sync_device();
225
+ tensor_Bias0.sync_device();
226
+ tensor_D0.sync_device();
227
+ tensor_B1.sync_device();
228
+ tensor_B1_reordered.sync_device();
229
+ tensor_C1.sync_device();
230
+ tensor_Bias1.sync_device();
231
+ tensor_D1.sync_device();
232
+ reference_D0.sync_device();
233
+ reference_D1.sync_device();
234
+
235
+ //
236
+ // Initialize the GEMM operator
237
+ //
238
+
239
+ typename Gemm0::Arguments arguments_0{
240
+ problem_size_0,
241
+ tensor_A0.device_ref(),
242
+ tensor_B0_reordered.device_ref(),
243
+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
244
+ tensor_D0.device_ref(),
245
+ {alpha0, beta0}
246
+ };
247
+
248
+ typename Gemm1::Arguments arguments_1{
249
+ problem_size_1,
250
+ tensor_D0.device_ref(),
251
+ tensor_B1_reordered.device_ref(),
252
+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
253
+ tensor_D1.device_ref(),
254
+ {alpha1, beta1}
255
+ };
256
+
257
+
258
+ Gemm0 gemm_op_0;
259
+ Gemm1 gemm_op_1;
260
+
261
+ cutlass::Status status = gemm_op_0.initialize(arguments_0);
262
+
263
+ CUTLASS_CHECK(status);
264
+
265
+ status = gemm_op_1.initialize(arguments_1);
266
+
267
+ CUTLASS_CHECK(status);
268
+
269
+ for(int i = 0; i < warm_ups; i++) {
270
+ status = gemm_op_0();
271
+ CUTLASS_CHECK(status);
272
+ status = gemm_op_1();
273
+ CUTLASS_CHECK(status);
274
+ }
275
+
276
+ //
277
+ // Run the GEMM
278
+ //
279
+ cudaEvent_t start, stop1, stop2;
280
+ cudaEventCreate(&start);
281
+ cudaEventCreate(&stop1);
282
+ cudaEventCreate(&stop2);
283
+
284
+ cudaEventRecord(start);
285
+
286
+ for(int i = 0; i < runs; i++) {
287
+ status = gemm_op_0();
288
+
289
+ CUTLASS_CHECK(status);
290
+ }
291
+ cudaEventRecord(stop1);
292
+ for(int i = 0; i < runs; i++) {
293
+ status = gemm_op_1();
294
+
295
+ CUTLASS_CHECK(status);
296
+ }
297
+
298
+ cudaEventRecord(stop2);
299
+ cudaDeviceSynchronize();
300
+ float gemm0Time, gemm1Time, totalTime;
301
+ cudaEventElapsedTime(&gemm0Time, start, stop1);
302
+ cudaEventElapsedTime(&gemm1Time, stop1, stop2);
303
+ cudaEventElapsedTime(&totalTime, start, stop2);
304
+ std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
305
+ std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
306
+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n";
307
+
308
+ tensor_D0.sync_host();
309
+ tensor_D1.sync_host();
310
+
311
+ //
312
+ // Verify
313
+ //
314
+ cutlass::reference::device::Gemm<
315
+ typename Gemm0::ElementA, typename Gemm0::LayoutA,
316
+ typename Gemm0::ElementB, typename Gemm0::LayoutB,
317
+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
318
+ ElementAccumulator, typename Gemm0::Operator>
319
+ reference_gemm_0;
320
+
321
+ cutlass::reference::device::Gemm<
322
+ typename Gemm1::ElementA, typename Gemm1::LayoutA,
323
+ typename Gemm1::ElementB, typename Gemm1::LayoutB,
324
+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
325
+ ElementAccumulator, typename Gemm1::Operator>
326
+ reference_gemm_1;
327
+
328
+ reference_gemm_0(
329
+ problem_size_0,
330
+ alpha0,
331
+ tensor_A0.device_ref(),
332
+ tensor_B0.device_ref(),
333
+ beta0,
334
+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
335
+ reference_D0.device_ref()
336
+ );
337
+
338
+ if(relu) {
339
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
340
+ }
341
+
342
+ reference_gemm_1(
343
+ problem_size_1,
344
+ alpha1,
345
+ reference_D0.device_ref(),
346
+ tensor_B1.device_ref(),
347
+ beta1,
348
+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
349
+ reference_D1.device_ref()
350
+ );
351
+
352
+ if(relu) {
353
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
354
+ }
355
+
356
+ // Wait for kernels to finish
357
+ cudaDeviceSynchronize();
358
+ reference_D0.sync_host();
359
+ reference_D1.sync_host();
360
+
361
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
362
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
363
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
364
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
365
+
366
+ bool passed = cutlass::reference::host::TensorEquals(
367
+ reference_D1.host_view(),
368
+ tensor_D1.host_view());
369
+
370
+ CHECK_TRUE(passed);
371
+ if (!passed) {
372
+
373
+ std::stringstream fname;
374
+
375
+ fname << "error_B2bGemm_device_interleaved_nonfused.txt";
376
+ std::cerr << "Dumping results in " << fname.str() << "\n";
377
+
378
+ std::ofstream file(fname.str());
379
+
380
+ file
381
+ << "A0 =\n" << tensor_A0.host_view()
382
+ << "\nB0 =\n" << tensor_B0.host_view()
383
+ << "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
384
+ << "\nC0 =\n" << tensor_C0.host_view()
385
+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
386
+ << "\nD0 =\n" << tensor_D0.host_view()
387
+ << "\nB1 =\n" << tensor_B1.host_view()
388
+ << "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
389
+ << "\nC1 =\n" << tensor_C1.host_view()
390
+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
391
+ << "\n\nReference =\n" << reference_D1.host_view()
392
+ << "\nComputed =\n" << tensor_D1.host_view();
393
+ }
394
+ return passed;
395
+ }
396
+ };
397
+
398
+ template <typename B2bGemm_, int InterleavedK_>
399
+ struct B2bInterleavedFusedGemmRun
400
+ {
401
+
402
+ using B2bGemm = B2bGemm_;
403
+ using ElementAccumulator = typename B2bGemm::ElementAccumulator;
404
+ using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
405
+
406
+ /// Initialization
407
+ cutlass::Distribution::Kind init_A;
408
+ cutlass::Distribution::Kind init_B;
409
+ cutlass::Distribution::Kind init_C;
410
+ cutlass::Distribution::Kind init_Scale;
411
+ cutlass::Distribution::Kind init_Bias;
412
+ uint64_t seed;
413
+
414
+ //
415
+ // Methods
416
+ //
417
+
418
+ B2bInterleavedFusedGemmRun(
419
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
420
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
421
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
422
+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
423
+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
424
+ uint64_t seed_ = 2080
425
+ ):
426
+ init_A(init_A_), init_B(init_B_), init_C(init_C_),
427
+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
428
+
429
+ /// Helper to initialize a tensor view
430
+ template <typename Element, typename Layout>
431
+ bool initialize_tensor(
432
+ cutlass::TensorView<Element, Layout> view,
433
+ cutlass::Distribution::Kind dist_kind,
434
+ uint64_t seed) {
435
+
436
+ if (dist_kind == cutlass::Distribution::Uniform) {
437
+
438
+ cutlass::reference::host::TensorFillRandomUniform(
439
+ view, seed, 2, -2, 0);
440
+ }
441
+ else if (dist_kind == cutlass::Distribution::Identity) {
442
+
443
+ cutlass::reference::host::TensorFillIdentity(view);
444
+ }
445
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
446
+
447
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
448
+ }
449
+ else if (dist_kind == cutlass::Distribution::Sequential) {
450
+
451
+ cutlass::reference::host::BlockFillSequential(
452
+ view.data(), view.capacity());
453
+ }
454
+ else if (dist_kind == cutlass::Distribution::AllZeros) {
455
+ cutlass::reference::host::TensorFill(view, Element(0));
456
+ }
457
+ else if (dist_kind == cutlass::Distribution::AllOnes) {
458
+ cutlass::reference::host::TensorFill(view, Element(1));
459
+ }
460
+ else {
461
+ std::cerr << "Not implemented\n";
462
+ return false;
463
+ }
464
+
465
+ return true;
466
+ }
467
+
468
+
469
+
470
+
471
+ /// Executes one test
472
+ bool run(
473
+ cutlass::gemm::GemmCoord problem_size_0,
474
+ cutlass::gemm::GemmCoord problem_size_1,
475
+ ElementCompute alpha0 = ElementCompute(1),
476
+ ElementCompute beta0 = ElementCompute(0),
477
+ ElementCompute alpha1 = ElementCompute(1),
478
+ ElementCompute beta1 = ElementCompute(0),
479
+ cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
480
+
481
+ // batch_count is used as split-k when mode is kGemm according
482
+ // to the GemmUniversal interface
483
+
484
+ int batch_count = 1,
485
+
486
+ int64_t batch_stride_A0 = 0,
487
+ int64_t batch_stride_B0 = 0,
488
+ int64_t batch_stride_C0 = 0,
489
+ int64_t batch_stride_B1 = 0,
490
+ int64_t batch_stride_C1 = 0,
491
+ int64_t batch_stride_D1 = 0,
492
+ int64_t batch_stride_Bias0 = 0,
493
+ int64_t batch_stride_Scale0 = 0,
494
+ bool relu = true,
495
+ int warm_ups = 1,
496
+ int runs = 100) {
497
+
498
+ //
499
+ // Allocate the GEMM workspace
500
+ //
501
+
502
+ cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
503
+ cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
504
+ cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
505
+ cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
506
+ cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
507
+
508
+ cutlass::HostTensor<
509
+ typename B2bGemm::ElementA,
510
+ typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
511
+
512
+ cutlass::HostTensor<
513
+ typename B2bGemm::ElementB,
514
+ typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
515
+
516
+ cutlass::HostTensor<
517
+ typename B2bGemm::ElementB,
518
+ typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
519
+
520
+ cutlass::HostTensor<
521
+ typename B2bGemm::ElementC,
522
+ typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
523
+
524
+ cutlass::HostTensor<
525
+ typename B2bGemm::ElementScaleBias,
526
+ typename B2bGemm::LayoutScaleBias> tensor_Scale0;
527
+
528
+ if(alpha0 == ElementCompute(0)) //per-channel scale
529
+ tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
530
+
531
+ cutlass::HostTensor<
532
+ typename B2bGemm::ElementScaleBias,
533
+ typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
534
+
535
+ cutlass::HostTensor<
536
+ ElementAccumulator,
537
+ typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
538
+
539
+ cutlass::HostTensor<
540
+ typename B2bGemm::ElementC,
541
+ typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
542
+
543
+ cutlass::HostTensor<
544
+ typename B2bGemm::ElementB,
545
+ typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
546
+
547
+ cutlass::HostTensor<
548
+ typename B2bGemm::ElementB,
549
+ typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
550
+
551
+ cutlass::HostTensor<
552
+ typename B2bGemm::ElementC,
553
+ typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
554
+
555
+ cutlass::HostTensor<
556
+ typename B2bGemm::ElementC,
557
+ typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
558
+
559
+ cutlass::HostTensor<
560
+ typename B2bGemm::ElementC,
561
+ typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
562
+
563
+ cutlass::HostTensor<
564
+ typename B2bGemm::ElementC,
565
+ typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
566
+
567
+
568
+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
569
+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
570
+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
571
+ if(alpha0 == ElementCompute(0)) //per-channel scale
572
+ CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
573
+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
574
+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
575
+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
576
+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
577
+
578
+ //Reorder B0
579
+ cutlass::reorder_column<16>(
580
+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
581
+ cutlass::reorder_column<InterleavedK_>(
582
+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
583
+
584
+ cutlass::reference::host::TensorFill(
585
+ tensor_D1.host_view());
586
+ cutlass::reference::host::TensorFill(
587
+ reference_D0.host_view());
588
+ cutlass::reference::host::TensorFill(
589
+ reference_D1.host_view());
590
+
591
+ tensor_A0.sync_device();
592
+ tensor_B0.sync_device();
593
+ tensor_B0_reordered.sync_device();
594
+ tensor_C0.sync_device();
595
+ if(alpha0 == ElementCompute(0)) //per-channel scale
596
+ tensor_Scale0.sync_device();
597
+ tensor_Bias0.sync_device();
598
+ tensor_B1.sync_device();
599
+ tensor_B1_reordered.sync_device();
600
+ tensor_C1.sync_device();
601
+ tensor_Bias1.sync_device();
602
+ tensor_D1.sync_device();
603
+ reference_D0.sync_device();
604
+ reference_D1.sync_device();
605
+ // tensor_Bias0_batched.sync_device();
606
+
607
+ //
608
+ // Initialize the GEMM operator
609
+ //
610
+
611
+ typename B2bGemm::Arguments arguments{
612
+ mode,
613
+ problem_size_0,
614
+ problem_size_1,
615
+ tensor_A0.device_ref(),
616
+ tensor_B0_reordered.device_ref(),
617
+ tensor_C0.device_ref(),
618
+ tensor_Scale0.device_ref(),
619
+ tensor_Bias0.device_ref(),
620
+ tensor_B1_reordered.device_ref(),
621
+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
622
+ tensor_D1.device_ref(),
623
+ batch_stride_A0,
624
+ batch_stride_B0,
625
+ batch_stride_B1,
626
+ batch_stride_C1,
627
+ batch_stride_D1,
628
+ batch_stride_Bias0,
629
+ batch_stride_Scale0,
630
+ {alpha0, beta0},
631
+ {alpha1, beta1},
632
+ batch_count,
633
+ };
634
+
635
+ B2bGemm b2b_gemm_op;
636
+
637
+ cutlass::Status status = b2b_gemm_op.can_implement(arguments);
638
+
639
+ if(status != cutlass::Status::kSuccess) {
640
+ std::cout << "Problem sizes not supported.\n"
641
+ << "Requirments:\n"
642
+ << " problem_size_0.M = problem_size_1.M\n"
643
+ << " problem_size_0.N = problem_size_1.K\n"
644
+ << " ThreadblockShape0::kN = problem_size_0.N\n"
645
+ << " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
646
+ }
647
+
648
+ status = b2b_gemm_op.initialize(arguments);
649
+
650
+ CUTLASS_CHECK(status);
651
+
652
+ for(int i = 0; i < warm_ups; i++) {
653
+ status = b2b_gemm_op();
654
+ CUTLASS_CHECK(status);
655
+ }
656
+
657
+ //
658
+ // Run the GEMM
659
+ //
660
+
661
+ cudaEvent_t start, stop;
662
+ cudaEventCreate(&start);
663
+ cudaEventCreate(&stop);
664
+
665
+ cudaEventRecord(start);
666
+
667
+ for(int i = 0; i < runs; i++) {
668
+ status = b2b_gemm_op();
669
+
670
+ CUTLASS_CHECK(status);
671
+ }
672
+
673
+ cudaEventRecord(stop);
674
+ cudaDeviceSynchronize();
675
+ float gemmTime;
676
+ cudaEventElapsedTime(&gemmTime, start, stop);
677
+ std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
678
+
679
+ tensor_D1.sync_host();
680
+
681
+ //
682
+ // Verify
683
+ //
684
+
685
+ cutlass::reference::device::GemmComplex<
686
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
687
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
688
+ ElementAccumulator, typename B2bGemm::LayoutC,
689
+ ElementAccumulator, ElementAccumulator
690
+ >(
691
+ problem_size_0,
692
+ ElementAccumulator(1), //intermediate alpha=1
693
+ tensor_A0.device_ref(),
694
+ cutlass::ComplexTransform::kNone,
695
+ tensor_B0.device_ref(),
696
+ cutlass::ComplexTransform::kNone,
697
+ ElementAccumulator(0), //beta = 0
698
+ reference_Z0.device_ref(),
699
+ reference_Z0.device_ref(),
700
+ ElementAccumulator(0),
701
+ int(batch_count),
702
+ batch_stride_A0,
703
+ batch_stride_B0,
704
+ batch_stride_C0,
705
+ batch_stride_C0
706
+ );
707
+
708
+ cutlass::reference::device::TensorScaleBiasGemmBatched<
709
+ ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
710
+ ElementCompute, typename B2bGemm::LayoutScaleBias
711
+ > (
712
+ problem_size_0,
713
+ reference_Z0.device_ref(),
714
+ reference_D0.device_ref(),
715
+ alpha0,
716
+ tensor_Scale0.device_ref(),
717
+ tensor_Bias0.device_ref(),
718
+ int(batch_count),
719
+ batch_stride_C0,
720
+ batch_stride_C0,
721
+ batch_stride_Scale0,
722
+ batch_stride_Bias0
723
+ );
724
+
725
+ if(relu) {
726
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
727
+ }
728
+
729
+ cutlass::reference::device::GemmComplex<
730
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
731
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
732
+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
733
+ ElementCompute, ElementAccumulator
734
+ >(
735
+ problem_size_1,
736
+ alpha1, //intermediate alpha=1
737
+ reference_D0.device_ref(),
738
+ cutlass::ComplexTransform::kNone,
739
+ tensor_B1.device_ref(),
740
+ cutlass::ComplexTransform::kNone,
741
+ beta1, //beta = 0
742
+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
743
+ reference_D1.device_ref(),
744
+ ElementAccumulator(0),
745
+ int(batch_count),
746
+ batch_stride_C0,
747
+ batch_stride_B1,
748
+ batch_stride_C1,
749
+ batch_stride_D1
750
+ );
751
+
752
+ if(relu) {
753
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
754
+ }
755
+
756
+ cudaDeviceSynchronize();
757
+ reference_D0.sync_host();
758
+ reference_D1.sync_host();
759
+
760
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
761
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
762
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
763
+
764
+ bool passed = cutlass::reference::host::TensorEquals(
765
+ reference_D1.host_view(),
766
+ tensor_D1.host_view());
767
+
768
+ CHECK_TRUE(passed);
769
+ if (!passed)
770
+ {
771
+
772
+ std::stringstream fname;
773
+
774
+ fname << "error_B2bGemm_device_interleaved_fused.txt";
775
+ std::cerr << "Dumping results in " << fname.str() << "\n";
776
+
777
+ std::ofstream file(fname.str());
778
+
779
+ file
780
+ << "A0 =\n" << tensor_A0.host_view()
781
+ << "\nB0 =\n" << tensor_B0.host_view()
782
+ << "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
783
+ << "\nC0 =\n" << tensor_C0.host_view()
784
+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
785
+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
786
+ << "\nB1 =\n" << tensor_B1.host_view()
787
+ << "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
788
+ << "\nC1 =\n" << tensor_C1.host_view()
789
+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
790
+ << "\n\nReference =\n" << reference_D1.host_view()
791
+ << "\nComputed =\n" << tensor_D1.host_view();
792
+ }
793
+ return passed;
794
+ }
795
+
796
+ };
797
+
798
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/numeric_types.h"
39
+ #include "cutlass/arch/arch.h"
40
+ #include "cutlass/device_kernel.h"
41
+
42
+ #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
43
+
44
+ #include "cutlass/gemm/device/default_gemm_configuration.h"
45
+ #include "cutlass/epilogue/thread/linear_combination_relu.h"
46
+
47
+ #include "kernel/b2b_gemm.h"
48
+ #include "kernel/default_b2b_gemm.h"
49
+ #include "kernel/default_b2b_gemm_smem_accumulator.h"
50
+
51
+ ////////////////////////////////////////////////////////////////////////////////
52
+
53
+ namespace cutlass {
54
+ namespace gemm {
55
+ namespace device {
56
+
57
+ /////////////////////////////////////////////////////////////////////////////////////////////////
58
+
59
+ template <
60
+ /// Element type for A matrix operand
61
+ typename ElementA_,
62
+ /// Layout type for A matrix operand
63
+ typename LayoutA_,
64
+ /// Element type for B matrix operand
65
+ typename ElementB_,
66
+ /// Layout type for B matrix operand
67
+ typename LayoutB_,
68
+ /// Element type for C and D matrix operands
69
+ typename ElementC_,
70
+ /// Layout type for C and D matrix operands
71
+ typename LayoutC_,
72
+ /// Element type for internal accumulation
73
+ typename ElementAccumulator_ = ElementC_,
74
+ /// Operator class tag
75
+ typename OperatorClass_ = arch::OpClassSimt,
76
+ /// Tag indicating architecture to tune for
77
+ typename ArchTag_ = arch::Sm70,
78
+ /// Threadblock-level tile size (concept: GemmShape)
79
+ typename ThreadblockShape0_ = typename DefaultGemmConfiguration<
80
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
81
+ ElementAccumulator_>::ThreadblockShape,
82
+ /// Threadblock-level tile size (concept: GemmShape)
83
+ typename ThreadblockShape1_ = typename DefaultGemmConfiguration<
84
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
85
+ ElementAccumulator_>::ThreadblockShape,
86
+ /// Warp-level tile size (concept: GemmShape)
87
+ typename WarpShape0_ = typename DefaultGemmConfiguration<
88
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
89
+ ElementAccumulator_>::WarpShape,
90
+ /// Warp-level tile size (concept: GemmShape)
91
+ typename WarpShape1_ = typename DefaultGemmConfiguration<
92
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
93
+ ElementAccumulator_>::WarpShape,
94
+ /// Instruction-level tile size (concept: GemmShape)
95
+ typename InstructionShape_ = typename DefaultGemmConfiguration<
96
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
97
+ ElementAccumulator_>::InstructionShape,
98
+ /// Epilogue output operator
99
+ typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration<
100
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
101
+ ElementAccumulator_>::EpilogueOutputOp,
102
+ /// Epilogue output operator
103
+ typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration<
104
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
105
+ ElementAccumulator_>::EpilogueOutputOp,
106
+ /// Threadblock-level swizzling operator
107
+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
108
+ /// Number of stages used in the pipelined mainloop
109
+ int Stages =
110
+ DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
111
+ ElementC_, ElementAccumulator_>::kStages,
112
+ /// Stage accumulator in shared memory
113
+ bool SmemAccumulator = false,
114
+ /// Access granularity of A matrix in units of elements
115
+ int AlignmentA =
116
+ DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
117
+ ElementC_, ElementAccumulator_>::kAlignmentA,
118
+ /// Access granularity of B matrix in units of elements
119
+ int AlignmentB =
120
+ DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
121
+ ElementC_, ElementAccumulator_>::kAlignmentB,
122
+ /// Operation performed by GEMM
123
+ typename Operator_ = typename DefaultGemmConfiguration<
124
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
125
+ ElementAccumulator_>::Operator>
126
+ class B2bGemm {
127
+ public:
128
+
129
+ using ElementA = ElementA_;
130
+ using LayoutA = LayoutA_;
131
+ using TensorRefA = TensorRef<ElementA const, LayoutA>;
132
+ using ElementB = ElementB_;
133
+ using LayoutB = LayoutB_;
134
+ using TensorRefB = TensorRef<ElementB const, LayoutB>;
135
+ using ElementC = ElementC_;
136
+ using LayoutC = LayoutC_;
137
+ using TensorRefC = TensorRef<ElementC const, LayoutC>;
138
+ using TensorRefD = TensorRef<ElementC, LayoutC>;
139
+ using ElementAccumulator = ElementAccumulator_;
140
+ using OperatorClass = OperatorClass_;
141
+ using ArchTag = ArchTag_;
142
+ using ThreadblockShape0 = ThreadblockShape0_;
143
+ using ThreadblockShape1 = ThreadblockShape1_;
144
+ using WarpShape0 = WarpShape0_;
145
+ using WarpShape1 = WarpShape1_;
146
+ using InstructionShape = InstructionShape_;
147
+ using EpilogueOutputOp0 = EpilogueOutputOp0_;
148
+ using EpilogueOutputOp1 = EpilogueOutputOp1_;
149
+ using ThreadblockSwizzle = ThreadblockSwizzle_;
150
+ using Operator = Operator_;
151
+ static int const kStages = Stages;
152
+ static int const kAlignmentA = AlignmentA;
153
+ static int const kAlignmentB = AlignmentB;
154
+ static int const kAlignmentC = EpilogueOutputOp1::kCount;
155
+ static ComplexTransform const kTransformA = ComplexTransform::kNone;
156
+ static ComplexTransform const kTransformB = ComplexTransform::kNone;
157
+
158
+ /// Derived types
159
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
160
+ using LayoutScaleBias = layout::RowMajor;
161
+
162
+ /// Define the kernel
163
+ using B2bGemmKernel = typename kernel::DefaultB2bGemm<
164
+ ElementA,
165
+ LayoutA,
166
+ kAlignmentA,
167
+ ElementB,
168
+ LayoutB,
169
+ kAlignmentB,
170
+ ElementC,
171
+ LayoutC,
172
+ ElementAccumulator,
173
+ OperatorClass,
174
+ ArchTag,
175
+ ThreadblockShape0,
176
+ ThreadblockShape1,
177
+ WarpShape0,
178
+ WarpShape1,
179
+ InstructionShape,
180
+ EpilogueOutputOp0,
181
+ EpilogueOutputOp1,
182
+ ThreadblockSwizzle,
183
+ kStages,
184
+ Operator,
185
+ SmemAccumulator
186
+ >::B2bGemmKernel;
187
+
188
+ using Arguments = typename B2bGemmKernel::Arguments;
189
+
190
+ private:
191
+
192
+ /// Kernel parameters object
193
+ typename B2bGemmKernel::Params params_;
194
+
195
+ public:
196
+
197
+ /// Constructs the GEMM.
198
+ B2bGemm() { }
199
+
200
+ /// Determines whether the GEMM can execute the given problem.
201
+ static Status can_implement(Arguments const &args) {
202
+
203
+ Status status = B2bGemmKernel::can_implement(
204
+ args.problem_size_0,
205
+ args.problem_size_1,
206
+ args.ref_A0.non_const_ref(),
207
+ args.ref_B0.non_const_ref(),
208
+ args.ref_C0.non_const_ref(),
209
+ args.ref_B1.non_const_ref(),
210
+ args.ref_C1.non_const_ref(),
211
+ args.ref_D1
212
+ );
213
+
214
+ if (status != Status::kSuccess) {
215
+ return status;
216
+ }
217
+
218
+ return Status::kSuccess;
219
+ }
220
+
221
+ /// Gets the workspace size
222
+ static size_t get_workspace_size(Arguments const &args) {
223
+
224
+ size_t bytes = 0;
225
+
226
+ // Determine grid shape
227
+ ThreadblockSwizzle threadblock_swizzle;
228
+
229
+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
230
+ args.problem_size_0,
231
+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
232
+ args.batch_count);
233
+
234
+ return bytes;
235
+ }
236
+
237
+ /// Initializes GEMM state from arguments.
238
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
239
+
240
+ // Determine grid shape
241
+ ThreadblockSwizzle threadblock_swizzle;
242
+
243
+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
244
+ args.problem_size_0,
245
+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
246
+ args.batch_count);
247
+ // cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
248
+ // args.problem_size_1,
249
+ // {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
250
+ // args.batch_count);
251
+
252
+ // Initialize the Params structure
253
+ params_ = typename B2bGemmKernel::Params{
254
+ args.mode,
255
+ args.problem_size_0,
256
+ args.problem_size_1,
257
+ grid_shape,
258
+ args.ref_A0.non_const_ref(),
259
+ args.ref_B0.non_const_ref(),
260
+ args.ref_C0.non_const_ref(),
261
+ args.ref_Scale0.non_const_ref(),
262
+ args.ref_Bias0.non_const_ref(),
263
+ args.ref_B1.non_const_ref(),
264
+ args.ref_C1.non_const_ref(),
265
+ args.ref_D1,
266
+ args.batch_stride_A0,
267
+ args.batch_stride_B0,
268
+ args.batch_stride_B1,
269
+ args.batch_stride_C1,
270
+ args.batch_stride_D1,
271
+ args.batch_stride_Bias0,
272
+ args.batch_stride_Scale0,
273
+ args.epilogue0,
274
+ args.epilogue1,
275
+ static_cast<int *>(workspace),
276
+ };
277
+
278
+ return Status::kSuccess;
279
+ }
280
+
281
+ /// Lightweight update given a subset of arguments
282
+ Status update(Arguments const &args, void *workspace = nullptr) {
283
+
284
+ params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
285
+ params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
286
+ params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
287
+ params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data());
288
+ params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data());
289
+ params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
290
+ params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
291
+ params_.ref_D1.reset(args.ref_D1.data());
292
+ params_.output_op_0 = args.epilogue0;
293
+ params_.output_op_1 = args.epilogue1;
294
+ params_.semaphore = static_cast<int *>(workspace);
295
+
296
+ return Status::kSuccess;
297
+ }
298
+
299
+ /// Runs the kernel using initialized state.
300
+ Status run(cudaStream_t stream = nullptr) {
301
+
302
+ ThreadblockSwizzle threadblock_swizzle;
303
+
304
+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
305
+ dim3 block(B2bGemmKernel::kThreadCount, 1, 1);
306
+
307
+ cudaError_t result;
308
+
309
+ int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));
310
+ if (smem_size >= (48 << 10)) {
311
+ result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>,
312
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
313
+ smem_size);
314
+
315
+ if (result != cudaSuccess) {
316
+ return Status::kErrorInternal;
317
+ }
318
+ }
319
+
320
+ cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);
321
+
322
+ result = cudaGetLastError();
323
+
324
+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
325
+ }
326
+
327
+ /// Runs the kernel using initialized state.
328
+ Status operator()(cudaStream_t stream = nullptr) {
329
+ return run(stream);
330
+ }
331
+
332
+ /// Runs the kernel using initialized state.
333
+ Status operator()(
334
+ Arguments const &args,
335
+ void *workspace = nullptr,
336
+ cudaStream_t stream = nullptr) {
337
+
338
+ Status status = initialize(args, workspace, stream);
339
+
340
+ if (status == Status::kSuccess) {
341
+ status = run(stream);
342
+ }
343
+
344
+ return status;
345
+ }
346
+ };
347
+
348
+ } // namespace device
349
+ } // namespace gemm
350
+ } // namespace cutlass
351
+
352
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Template for device-level Implicit GEMM
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include <limits>
38
+
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/device_kernel.h"
41
+ #include "cutlass/conv/convolution.h"
42
+
43
+ #include "kernel/b2b_implicit_gemm_convolution.h"
44
+ #include "kernel/default_b2b_conv2d_fprop.h"
45
+ #include "kernel/default_b2b_conv2d_fprop_sm75.h"
46
+ #include "kernel/default_b2b_conv2d_fprop_sm80.h"
47
+ #include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h"
48
+ #include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h"
49
+
50
+ namespace cutlass {
51
+ namespace conv {
52
+ namespace device {
53
+
54
+ template<typename B2bImplicitGemmKernel_>
55
+ class B2bImplicitGemmConvolution {
56
+ public:
57
+
58
+ using B2bImplicitGemmKernel = B2bImplicitGemmKernel_;
59
+
60
+ using ElementA = typename B2bImplicitGemmKernel::ElementA;
61
+ using LayoutA = typename B2bImplicitGemmKernel::LayoutA;
62
+ using ElementB = typename B2bImplicitGemmKernel::ElementB;
63
+ using LayoutB = typename B2bImplicitGemmKernel::LayoutB;
64
+ using ElementC = typename B2bImplicitGemmKernel::ElementC;
65
+ using LayoutC = typename B2bImplicitGemmKernel::LayoutC;
66
+ using ElementAccumulator = typename B2bImplicitGemmKernel::ElementAccumulator;
67
+ using ElementCompute = typename B2bImplicitGemmKernel::ElementCompute;
68
+ using ElementScaleBias = typename B2bImplicitGemmKernel::ElementScaleBias;
69
+ using LayoutScaleBias = typename B2bImplicitGemmKernel::LayoutScaleBias;
70
+ using OperatorClass = typename B2bImplicitGemmKernel::OperatorClass;
71
+ using ArchTag = typename B2bImplicitGemmKernel::ArchTag;
72
+ using ThreadblockShape0 = typename B2bImplicitGemmKernel::ThreadblockShape0;
73
+ using ThreadblockShape1 = typename B2bImplicitGemmKernel::ThreadblockShape1;
74
+ using WarpShape0 = typename B2bImplicitGemmKernel::WarpShape0;
75
+ using WarpShape1 = typename B2bImplicitGemmKernel::WarpShape1;
76
+ using InstructionShape = typename B2bImplicitGemmKernel::InstructionShape;
77
+ using ThreadblockSwizzle = typename B2bImplicitGemmKernel::ThreadblockSwizzle;
78
+ using EpilogueOutputOp0 = typename B2bImplicitGemmKernel::EpilogueOutputOp0;
79
+ using EpilogueOutputOp1 = typename B2bImplicitGemmKernel::EpilogueOutputOp1;
80
+ static int const kStages = B2bImplicitGemmKernel::kStages;
81
+ static int const kConvDim = B2bImplicitGemmKernel::kConvDim;
82
+ using WarpMmaOperator0 = typename B2bImplicitGemmKernel::WarpMmaOperator0;
83
+ using WarpMmaOperator1 = typename B2bImplicitGemmKernel::WarpMmaOperator1;
84
+ using ArchMmaOperator = typename B2bImplicitGemmKernel::ArchMmaOperator;
85
+ using MathOperator = typename B2bImplicitGemmKernel::MathOperator;
86
+
87
+ static cutlass::conv::Operator const kConvolutionalOperator = B2bImplicitGemmKernel::kConvolutionalOperator;
88
+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = B2bImplicitGemmKernel::kIteratorAlgorithm;
89
+
90
+ static int const kWarpCount =
91
+ (ThreadblockShape0::kM / WarpShape0::kM) *
92
+ (ThreadblockShape0::kN / WarpShape0::kN);
93
+
94
+ /// Argument structure
95
+ using Arguments = typename B2bImplicitGemmKernel::Arguments;
96
+
97
+ private:
98
+
99
+ /// Kernel parameters object
100
+ typename B2bImplicitGemmKernel::Params params_;
101
+
102
+ public:
103
+
104
+ /// Constructs Implicit GEMM
105
+ B2bImplicitGemmConvolution() { }
106
+
107
+ /// Determines whether the Implicit GEMM can execute the given problem.
108
+ static Status can_implement(Arguments const &args) {
109
+
110
+ // dispatch to iterators
111
+ Status status = B2bImplicitGemmKernel::B2bMma::IteratorA0::can_implement(args.problem_size_0);
112
+ if (Status::kSuccess != status) {
113
+ return status;
114
+ }
115
+
116
+ status = B2bImplicitGemmKernel::B2bMma::IteratorB0::can_implement(args.problem_size_0);
117
+ if (Status::kSuccess != status) {
118
+ return status;
119
+ }
120
+
121
+ status = B2bImplicitGemmKernel::B2bMma::IteratorB1::can_implement(args.problem_size_1);
122
+ if (Status::kSuccess != status) {
123
+ return status;
124
+ }
125
+
126
+ // Determine grid shape
127
+ ThreadblockSwizzle threadblock_swizzle;
128
+
129
+ dim3 grid = threadblock_swizzle.get_grid_shape(
130
+ threadblock_swizzle.get_tiled_shape(
131
+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0),
132
+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
133
+ args.problem_size_0.split_k_slices));
134
+
135
+ if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
136
+ grid.z <= std::numeric_limits<uint16_t>::max())) {
137
+
138
+ return Status::kErrorInvalidProblem;
139
+ }
140
+
141
+ // Determine if fusion sizes are valid
142
+
143
+ cutlass::gemm::GemmCoord problem_size_0 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0);
144
+ cutlass::gemm::GemmCoord problem_size_1 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1);
145
+
146
+ if(problem_size_0.m() != problem_size_1.m())
147
+ return Status::kErrorInvalidProblem;
148
+
149
+ if(problem_size_0.n() != problem_size_1.k())
150
+ return Status::kErrorInvalidProblem;
151
+
152
+ if(args.problem_size_1.R != 1 || args.problem_size_1.S != 1)
153
+ return Status::kErrorInvalidProblem;
154
+
155
+ if(problem_size_0.n() > ThreadblockShape0::kN)
156
+ return Status::kErrorInvalidProblem;
157
+
158
+ if(problem_size_1.n() > ThreadblockShape1::kN)
159
+ return Status::kErrorInvalidProblem;
160
+
161
+ return Status::kSuccess;
162
+ }
163
+
164
+ /// Gets the workspace size
165
+ static size_t get_workspace_size(Arguments const &args) {
166
+
167
+ size_t workspace_bytes = 0;
168
+
169
+ // Determine grid shape
170
+ ThreadblockSwizzle threadblock_swizzle;
171
+
172
+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
173
+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0),
174
+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
175
+ args.problem_size_0.split_k_slices);
176
+
177
+ if(args.split_k_mode == SplitKMode::kParallel) {
178
+
179
+ // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace.
180
+ // The user needs to call a reduction operator to obtain the final output tensor
181
+ workspace_bytes =
182
+ sizeof(ElementAccumulator) *
183
+ size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) *
184
+ size_t(grid_tiled_shape.k());
185
+ }
186
+
187
+ else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size_0.split_k_slices > 1) {
188
+
189
+ // Split-K serial: The user workspace is used to store semaphore and serialize writing the
190
+ // final reduced output to user's output tensor
191
+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
192
+ }
193
+
194
+ return workspace_bytes;
195
+ }
196
+
197
+ /// Initializes GEMM state from arguments.
198
+ Status initialize(
199
+ Arguments const &args,
200
+ void *workspace = nullptr,
201
+ cudaStream_t stream = nullptr) {
202
+
203
+ if (args.problem_size_0.split_k_slices > 1) {
204
+
205
+ if (!workspace) {
206
+ return Status::kErrorWorkspaceNull;
207
+ }
208
+
209
+ cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
210
+
211
+ if (status != cudaSuccess) {
212
+ return Status::kErrorInternal;
213
+ }
214
+ }
215
+
216
+ // initialize the params structure from the arguments
217
+ params_ = typename B2bImplicitGemmKernel::Params(
218
+ args,
219
+ static_cast<int *>(workspace)
220
+ );
221
+
222
+ int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
223
+
224
+ if (smem_size >= (48 << 10)) {
225
+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<B2bImplicitGemmKernel>,
226
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
227
+ smem_size);
228
+
229
+ if (result != cudaSuccess) {
230
+ return Status::kErrorInternal;
231
+ }
232
+ }
233
+
234
+ return Status::kSuccess;
235
+ }
236
+
237
+ /// Initializes GEMM state from arguments.
238
+ Status update(Arguments const &args, void *workspace = nullptr) {
239
+
240
+ // update the params structure from the arguments
241
+ params_.ptr_A0 = args.ref_A0.data();
242
+ params_.ptr_B0 = args.ref_B0.data();
243
+ params_.ptr_C0 = args.ref_C0.data();
244
+ params_.ptr_Scale0 = args.ref_Scale0.data();
245
+ params_.ptr_Bias0 = args.ref_Bias0.data();
246
+ params_.ptr_B1 = args.ref_B1.data();
247
+ params_.ptr_C1 = args.ref_C1.data();
248
+ params_.ptr_D1 = args.ref_D1.data();
249
+ params_.output_op_0 = args.output_op_0;
250
+ params_.output_op_1 = args.output_op_1;
251
+ params_.semaphore = static_cast<int *>(workspace);
252
+
253
+ return Status::kSuccess;
254
+ }
255
+
256
+ /// Runs the kernel using initialized state.
257
+ Status run(cudaStream_t stream = nullptr) {
258
+
259
+ ThreadblockSwizzle threadblock_swizzle;
260
+
261
+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
262
+ dim3 block(32 * kWarpCount, 1, 1);
263
+
264
+ int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
265
+
266
+ cutlass::Kernel<B2bImplicitGemmKernel><<<grid, block, smem_size, stream>>>(params_);
267
+
268
+ cudaError_t result = cudaGetLastError();
269
+
270
+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
271
+ }
272
+
273
+ /// Runs the kernel using initialized state.
274
+ Status operator()(cudaStream_t stream = nullptr) {
275
+ return run(stream);
276
+ }
277
+
278
+ /// Runs the kernel using initialized state.
279
+ Status operator()(
280
+ Arguments const &args,
281
+ void *workspace = nullptr,
282
+ cudaStream_t stream = nullptr) {
283
+
284
+ Status status = initialize(args, workspace, stream);
285
+
286
+ if (status == Status::kSuccess) {
287
+ status = run(stream);
288
+ }
289
+
290
+ return status;
291
+ }
292
+ };
293
+
294
+ /////////////////////////////////////////////////////////////////////////////////////////////////
295
+
296
+ /////////////////////////////////////////////////////////////////////////////////////////////////
297
+ } // namespace device
298
+ } // namespace conv
299
+ } // namespace cutlass
300
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+
39
+ #include "cutlass/gemm/gemm.h"
40
+ #include "cutlass/matrix_coord.h"
41
+ #include "cutlass/semaphore.h"
42
+
43
+ #include "kernel/b2b_gemm_grouped_problem_visitor.h"
44
+ #include "threadblock/grouped_threadblock_swizzle.h"
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace gemm {
50
+ namespace kernel {
51
+
52
+ namespace detail {
53
+
54
+ /// Utility struct for returning the type of the problem visitor used by the swizzling function,
55
+ /// if it is a grouped swizzling function, or a default visitor. This is used only for defining
56
+ /// the parameters of the problem visitor used in GroupedParams.
57
+ template <
58
+ typename B2bMma_,
59
+ typename ThreadblockSwizzle_,
60
+ typename Enable = void
61
+ >
62
+ struct ProblemVisitorOrDefault;
63
+
64
+ /// Return a generic problem visitor for GEMM problems
65
+ template <
66
+ typename B2bMma_,
67
+ typename ThreadblockSwizzle_
68
+ >
69
+ struct ProblemVisitorOrDefault<B2bMma_,
70
+ ThreadblockSwizzle_,
71
+ typename platform::enable_if<
72
+ ! cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
73
+ >::type> {
74
+ using value = B2bGemmGroupedProblemVisitor<typename B2bMma_::Shape,
75
+ GroupScheduleMode::kDeviceOnly,
76
+ 128,
77
+ 128,
78
+ platform::is_same<typename B2bMma_::LayoutC,
79
+ cutlass::layout::ColumnMajor>::value>;
80
+ };
81
+
82
+ /// Return the problem visitor specified by the swizzling function
83
+ template <
84
+ typename B2bMma_,
85
+ typename ThreadblockSwizzle_
86
+ >
87
+ struct ProblemVisitorOrDefault<B2bMma_,
88
+ ThreadblockSwizzle_,
89
+ typename platform::enable_if<
90
+ cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
91
+ >::type> {
92
+ using value = typename ThreadblockSwizzle_::ProblemVisitor;
93
+ };
94
+
95
+ } // namespace detail
96
+
97
+ /////////////////////////////////////////////////////////////////////////////////////////////////
98
+
99
+ template <
100
+ typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
101
+ typename Epilogue_, ///! Epilogue
102
+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function
103
+ >
104
+ struct B2bGemm {
105
+
106
+ using B2bMma = B2bMma_;
107
+ using Epilogue = Epilogue_;
108
+ using OutputOp0 = typename B2bMma::OutputOp;
109
+ using OutputOp1 = typename Epilogue::OutputOp;
110
+ using ThreadblockSwizzle = ThreadblockSwizzle_;
111
+
112
+ using ElementA0 = typename B2bMma::IteratorA0::Element;
113
+ using LayoutA0 = typename B2bMma::IteratorA0::Layout;
114
+ using ElementB0 = typename B2bMma::IteratorB0::Element;
115
+ using LayoutB0 = typename B2bMma::IteratorB0::Layout;
116
+ using ElementB1 = typename B2bMma::IteratorB1::Element;
117
+ using LayoutB1 = typename B2bMma::IteratorB1::Layout;
118
+ using ElementC = typename Epilogue::OutputTileIterator::Element;
119
+ using LayoutC = typename Epilogue::OutputTileIterator::Layout;
120
+
121
+ using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
122
+
123
+ /// Data types needed for higher-level containers. In some cases, a single type must be exposed
124
+ /// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from
125
+ /// the second GEMM (other than for ElementA/ElementB)
126
+ using ElementA = typename B2bMma::IteratorA0::Element;
127
+ using LayoutA = typename B2bMma::IteratorA0::Layout;
128
+ using ElementB = typename B2bMma::IteratorB0::Element;
129
+ using LayoutB = typename B2bMma::IteratorB0::Layout;
130
+
131
+ static ComplexTransform const kTransformA = B2bMma::kTransformA;
132
+ static ComplexTransform const kTransformB = B2bMma::kTransformB;
133
+ using Operator = typename B2bMma::Operator0;
134
+
135
+ using OperatorClass = typename Operator::OperatorClass;
136
+ using ThreadblockShape = typename B2bMma::Shape0;
137
+ using WarpShape = typename Operator::Shape;
138
+ using InstructionShape = typename Operator::InstructionShape;
139
+ using ArchTag = typename B2bMma::ArchTag;
140
+
141
+ static int const kStages = B2bMma::kStages;
142
+ static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements;
143
+ static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements;
144
+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
145
+
146
+ using Mma = B2bMma;
147
+ using EpilogueOutputOp = OutputOp1;
148
+
149
+ /// Warp count (concept: GemmShape)
150
+ using WarpCount0 = typename B2bMma::WarpCount0;
151
+ static int const kThreadCount = 32 * WarpCount0::kCount;
152
+
153
+ /// Argument structure
154
+ struct Arguments {
155
+
156
+ //
157
+ // Data members
158
+ //
159
+
160
+ GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
161
+ GemmCoord problem_size_0{0,0,0};
162
+ GemmCoord problem_size_1{0,0,0};
163
+ typename B2bMma::IteratorA0::TensorRef ref_A0{};
164
+ typename B2bMma::IteratorB0::TensorRef ref_B0{};
165
+ typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
166
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
167
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
168
+ typename B2bMma::IteratorB1::TensorRef ref_B1{};
169
+ typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
170
+ typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
171
+ int64_t batch_stride_A0{0};
172
+ int64_t batch_stride_B0{0};
173
+ int64_t batch_stride_B1{0};
174
+ int64_t batch_stride_C1{0};
175
+ int64_t batch_stride_D1{0};
176
+ int64_t batch_stride_Bias0{0};
177
+ int64_t batch_stride_Scale0{0};
178
+ typename OutputOp0::Params epilogue0 {};
179
+ typename OutputOp1::Params epilogue1 {};
180
+ int batch_count{1};
181
+
182
+ //
183
+ // Methods
184
+ //
185
+
186
+ /// Default ctor
187
+ Arguments() = default;
188
+
189
+ /// Constructs an Arguments structure
190
+ CUTLASS_HOST_DEVICE
191
+ Arguments(
192
+ GemmUniversalMode mode_,
193
+ GemmCoord problem_size_0_,
194
+ GemmCoord problem_size_1_,
195
+ typename B2bMma::IteratorA0::TensorRef ref_A0_,
196
+ typename B2bMma::IteratorB0::TensorRef ref_B0_,
197
+ typename Epilogue::OutputTileIterator::TensorRef ref_C0_,
198
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_,
199
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_,
200
+ typename B2bMma::IteratorB1::TensorRef ref_B1_,
201
+ typename Epilogue::OutputTileIterator::TensorRef ref_C1_,
202
+ typename Epilogue::OutputTileIterator::TensorRef ref_D1_,
203
+ int64_t batch_stride_A0_,
204
+ int64_t batch_stride_B0_,
205
+ int64_t batch_stride_B1_,
206
+ int64_t batch_stride_C1_,
207
+ int64_t batch_stride_D1_,
208
+ int64_t batch_stride_Bias0_,
209
+ int64_t batch_stride_Scale0_,
210
+ typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
211
+ typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
212
+ int batch_count_ = 1
213
+ ):
214
+ mode(mode_),
215
+ problem_size_0(problem_size_0_),
216
+ problem_size_1(problem_size_1_),
217
+ ref_A0(ref_A0_),
218
+ ref_B0(ref_B0_),
219
+ ref_C0(ref_C0_),
220
+ ref_Scale0(ref_Scale0_),
221
+ ref_Bias0(ref_Bias0_),
222
+ ref_B1(ref_B1_),
223
+ ref_C1(ref_C1_),
224
+ ref_D1(ref_D1_),
225
+ batch_stride_A0(batch_stride_A0_),
226
+ batch_stride_B0(batch_stride_B0_),
227
+ batch_stride_B1(batch_stride_B1_),
228
+ batch_stride_C1(batch_stride_C1_),
229
+ batch_stride_D1(batch_stride_D1_),
230
+ batch_stride_Bias0(batch_stride_Bias0_),
231
+ batch_stride_Scale0(batch_stride_Scale0_),
232
+ epilogue0(epilogue0_),
233
+ epilogue1(epilogue1_),
234
+ batch_count(batch_count_) {
235
+ }
236
+ };
237
+
238
+ // Arguments structure for grouped B2B problems
239
+ struct GroupedArguments {
240
+ GemmCoord* problem_size_0;
241
+ GemmCoord* problem_size_1;
242
+ typename B2bMma::IteratorA0::TensorRef* ref_A0;
243
+ typename B2bMma::IteratorB0::TensorRef* ref_B0;
244
+ typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
245
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
246
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
247
+ typename B2bMma::IteratorB1::TensorRef* ref_B1;
248
+ typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
249
+ typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
250
+
251
+ // Epilogue params remain constant across all problems in the group. Thus,
252
+ // the parameter here is not a pointer.
253
+ typename OutputOp0::Params epilogue0;
254
+ typename OutputOp1::Params epilogue1;
255
+
256
+ int problem_count;
257
+ int threadblock_count;
258
+ GemmCoord* host_problem_sizes;
259
+
260
+ CUTLASS_HOST_DEVICE
261
+ GroupedArguments(
262
+ int problem_count,
263
+ GemmCoord* problem_size_0_,
264
+ GemmCoord* problem_size_1_,
265
+ typename B2bMma::IteratorA0::TensorRef* ref_A0_,
266
+ typename B2bMma::IteratorB0::TensorRef* ref_B0_,
267
+ typename Epilogue::OutputTileIterator::TensorRef* ref_C0_,
268
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_,
269
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_,
270
+ typename B2bMma::IteratorB1::TensorRef* ref_B1_,
271
+ typename Epilogue::OutputTileIterator::TensorRef* ref_C1_,
272
+ typename Epilogue::OutputTileIterator::TensorRef* ref_D1_,
273
+ typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
274
+ typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
275
+ int threadblock_count = 0
276
+ ) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_),
277
+ ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_),
278
+ ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_),
279
+ ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_),
280
+ problem_count(problem_count),
281
+ threadblock_count(threadblock_count)
282
+ {}
283
+ };
284
+
285
+ /// Parameters structure
286
+ struct Params {
287
+ cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
288
+ cutlass::gemm::GemmCoord problem_size_0{};
289
+ cutlass::gemm::GemmCoord problem_size_1{};
290
+ cutlass::gemm::GemmCoord grid_tiled_shape{};
291
+ int swizzle_log_tile{0};
292
+ typename B2bMma::IteratorA0::Params params_A0{};
293
+ typename B2bMma::IteratorA0::TensorRef ref_A0{};
294
+ typename B2bMma::IteratorB0::Params params_B0{};
295
+ typename B2bMma::IteratorB0::TensorRef ref_B0{};
296
+ typename Epilogue::OutputTileIterator::Params params_C0{};
297
+ typename Epilogue::OutputTileIterator::TensorRef ref_C0{};
298
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{};
299
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{};
300
+ typename B2bMma::IteratorB1::Params params_B1{};
301
+ typename B2bMma::IteratorB1::TensorRef ref_B1{};
302
+ typename Epilogue::OutputTileIterator::Params params_C1{};
303
+ typename Epilogue::OutputTileIterator::TensorRef ref_C1{};
304
+ typename Epilogue::OutputTileIterator::Params params_D1{};
305
+ typename Epilogue::OutputTileIterator::TensorRef ref_D1{};
306
+ typename OutputOp0::Params output_op_0{};
307
+ typename OutputOp1::Params output_op_1{};
308
+ int64_t batch_stride_A0{0};
309
+ int64_t batch_stride_B0{0};
310
+ int64_t batch_stride_B1{0};
311
+ int64_t batch_stride_C1{0};
312
+ int64_t batch_stride_D1{0};
313
+ int64_t batch_stride_Bias0{0};
314
+ int64_t batch_stride_Scale0{0};
315
+ int *semaphore = nullptr;
316
+ int gemm_k_iterations_0{0};
317
+ int gemm_k_size_0{0};
318
+ int gemm_k_iterations_1{0};
319
+ int gemm_k_size_1{0};
320
+
321
+ //
322
+ // Methods
323
+ //
324
+
325
+ Params() = default;
326
+
327
+ CUTLASS_HOST_DEVICE
328
+ Params(
329
+ cutlass::gemm::GemmUniversalMode mode,
330
+ cutlass::gemm::GemmCoord const & problem_size_0,
331
+ cutlass::gemm::GemmCoord const & problem_size_1,
332
+ cutlass::gemm::GemmCoord const & grid_tiled_shape,
333
+ typename B2bMma::IteratorA0::TensorRef ref_A0,
334
+ typename B2bMma::IteratorB0::TensorRef ref_B0,
335
+ typename Epilogue::OutputTileIterator::TensorRef ref_C0,
336
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0,
337
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0,
338
+ typename B2bMma::IteratorB1::TensorRef ref_B1,
339
+ typename Epilogue::OutputTileIterator::TensorRef ref_C1,
340
+ typename Epilogue::OutputTileIterator::TensorRef ref_D1,
341
+ int64_t batch_stride_A0,
342
+ int64_t batch_stride_B0,
343
+ int64_t batch_stride_B1,
344
+ int64_t batch_stride_C1,
345
+ int64_t batch_stride_D1,
346
+ int64_t batch_stride_Bias0,
347
+ int64_t batch_stride_Scale0,
348
+ typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
349
+ typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
350
+ int *workspace = nullptr
351
+ ):
352
+ mode(mode),
353
+ problem_size_0(problem_size_0),
354
+ problem_size_1(problem_size_1),
355
+ grid_tiled_shape(grid_tiled_shape),
356
+ swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)),
357
+ params_A0(ref_A0.layout()),
358
+ ref_A0(ref_A0),
359
+ params_B0(ref_B0.layout()),
360
+ ref_B0(ref_B0),
361
+ params_C0(ref_C0.layout()),
362
+ ref_C0(ref_C0),
363
+ ref_Scale0(ref_Scale0),
364
+ ref_Bias0(ref_Bias0),
365
+ params_B1(ref_B1.layout()),
366
+ ref_B1(ref_B1),
367
+ params_C1(ref_C1.layout()),
368
+ ref_C1(ref_C1),
369
+ params_D1(ref_D1.layout()),
370
+ ref_D1(ref_D1),
371
+ batch_stride_A0(batch_stride_A0),
372
+ batch_stride_B0(batch_stride_B0),
373
+ batch_stride_B1(batch_stride_B1),
374
+ batch_stride_C1(batch_stride_C1),
375
+ batch_stride_D1(batch_stride_D1),
376
+ batch_stride_Bias0(batch_stride_Bias0),
377
+ batch_stride_Scale0(batch_stride_Scale0),
378
+ output_op_0(output_op_0),
379
+ output_op_1(output_op_1) {
380
+
381
+ int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
382
+ int gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
383
+ gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK;
384
+ int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
385
+ int gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
386
+ gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK;
387
+
388
+ semaphore = workspace;
389
+ }
390
+ };
391
+
392
+ struct GroupedParams {
393
+ cutlass::gemm::GemmCoord* problem_size_0;
394
+ cutlass::gemm::GemmCoord* problem_size_1;
395
+ cutlass::gemm::GemmCoord* grid_tiled_shape;
396
+ typename B2bMma::IteratorA0::TensorRef* ref_A0;
397
+ typename B2bMma::IteratorB0::TensorRef* ref_B0;
398
+ typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
399
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
400
+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
401
+ typename B2bMma::IteratorB1::TensorRef* ref_B1;
402
+ typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
403
+ typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
404
+
405
+ // Epilogue params remain constant across all problems in the group. Thus,
406
+ // the parameter here is not a pointer.
407
+ typename OutputOp0::Params output_op_0;
408
+ typename OutputOp1::Params output_op_1;
409
+
410
+ using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;
411
+ typename ProblemVisitor::Params problem_visitor;
412
+ int threadblock_count;
413
+ int* workspace;
414
+
415
+ CUTLASS_HOST_DEVICE
416
+ GroupedParams() {}
417
+
418
+ CUTLASS_HOST_DEVICE
419
+ GroupedParams(
420
+ GroupedArguments const &args,
421
+ void *workspace = nullptr,
422
+ int tile_count = 0
423
+ ) :
424
+ problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1),
425
+ ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0),
426
+ ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1),
427
+ output_op_0(args.epilogue0), output_op_1(args.epilogue1),
428
+ problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count),
429
+ threadblock_count(args.threadblock_count),
430
+ workspace(reinterpret_cast<int*>(workspace)) {}
431
+
432
+ CUTLASS_HOST_DEVICE
433
+ void transpose() {
434
+ // Only row-major outputs are currently supported, so no transpose is performed
435
+ }
436
+
437
+ /// Returns non-grouped parameters to be used as input to the kernel-level
438
+ /// operator for the problem indicated by problem_visitor.
439
+ CUTLASS_HOST_DEVICE
440
+ Params to_single_params(const ProblemVisitor& problem_visitor) const {
441
+ GemmCoord problem_size0 = problem_visitor.problem_size0();
442
+ GemmCoord problem_size1 = problem_visitor.problem_size1();
443
+ int32_t idx = problem_visitor.problem_index();
444
+ GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1);
445
+
446
+ return Params(
447
+ cutlass::gemm::GemmUniversalMode::kGemm,
448
+ problem_size0,
449
+ problem_size1,
450
+ grid_shape,
451
+ ref_A0[idx],
452
+ ref_B0[idx],
453
+ ref_C0[idx],
454
+ ref_Scale0[idx],
455
+ ref_Bias0[idx],
456
+ ref_B1[idx],
457
+ ref_C1[idx],
458
+ ref_D1[idx],
459
+ 0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported
460
+ output_op_0,
461
+ output_op_1,
462
+ workspace
463
+ );
464
+ }
465
+ };
466
+
467
+ /// Shared memory storage structure
468
+ union SharedStorage {
469
+ typename B2bMma::B2bMmaSharedStorage main_loop;
470
+ typename Epilogue::SharedStorage epilogue;
471
+ };
472
+
473
+ //
474
+ // Methods
475
+ //
476
+
477
+ CUTLASS_HOST_DEVICE
478
+ B2bGemm() { }
479
+
480
+ /// Determines whether kernel satisfies alignment
481
+ static Status can_implement(
482
+ cutlass::gemm::GemmCoord const & problem_size_0,
483
+ cutlass::gemm::GemmCoord const & problem_size_1,
484
+ typename B2bMma::IteratorA0::TensorRef ref_A0,
485
+ typename B2bMma::IteratorB0::TensorRef ref_B0,
486
+ typename Epilogue::OutputTileIterator::TensorRef ref_C0,
487
+ typename B2bMma::IteratorB1::TensorRef ref_B1,
488
+ typename Epilogue::OutputTileIterator::TensorRef ref_C1,
489
+ typename Epilogue::OutputTileIterator::TensorRef ref_D1) {
490
+
491
+ static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements;
492
+ static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements;
493
+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
494
+
495
+ if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
496
+ return Status::kErrorMisalignedOperand;
497
+ }
498
+
499
+ if (!TensorRef_aligned(ref_B0, kAlignmentB)) {
500
+ return Status::kErrorMisalignedOperand;
501
+ }
502
+
503
+ if (!TensorRef_aligned(ref_C0, kAlignmentC)) {
504
+ return Status::kErrorMisalignedOperand;
505
+ }
506
+
507
+ if (!TensorRef_aligned(ref_B1, kAlignmentB)) {
508
+ return Status::kErrorMisalignedOperand;
509
+ }
510
+
511
+ if (!TensorRef_aligned(ref_C1, kAlignmentC)) {
512
+ return Status::kErrorMisalignedOperand;
513
+ }
514
+
515
+ if (!TensorRef_aligned(ref_D1, kAlignmentC)) {
516
+ return Status::kErrorMisalignedOperand;
517
+ }
518
+
519
+ if ((problem_size_0.m() % kAlignmentA) || (problem_size_0.k() % kAlignmentA) ||
520
+ (problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) ||
521
+ (problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) ||
522
+ (problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) ||
523
+ (problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) ||
524
+ (problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) {
525
+
526
+ return Status::kErrorMisalignedOperand;
527
+ }
528
+
529
+ // Determine if fusion sizes are valid
530
+ if(problem_size_0.m() != problem_size_1.m())
531
+ return Status::kErrorInvalidProblem;
532
+
533
+ if(problem_size_0.n() != problem_size_1.k())
534
+ return Status::kErrorInvalidProblem;
535
+
536
+ if(problem_size_0.n() > B2bMma::Shape0::kN)
537
+ return Status::kErrorInvalidProblem;
538
+
539
+ if(problem_size_1.n() > B2bMma::Shape1::kN)
540
+ return Status::kErrorInvalidProblem;
541
+
542
+ return Status::kSuccess;
543
+ }
544
+
545
+ /// Executes one GEMM
546
+ CUTLASS_DEVICE
547
+ void operator()(Params const &params, SharedStorage &shared_storage) {
548
+ ThreadblockSwizzle threadblock_swizzle;
549
+ run_with_swizzle(params, shared_storage, threadblock_swizzle);
550
+ }
551
+
552
+ /// Executes one GEMM with an externally-provided swizzling function
553
+ CUTLASS_DEVICE
554
+ void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {
555
+
556
+ cutlass::gemm::GemmCoord threadblock_tile_offset =
557
+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
558
+
559
+ // Early exit if CTA is out of range
560
+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
561
+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
562
+
563
+ return;
564
+ }
565
+
566
+ ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
567
+ ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
568
+ ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
569
+
570
+ ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
571
+ ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
572
+
573
+ int offset_k_0 = 0;
574
+ int offset_k_1 = 0;
575
+
576
+ int problem_size_k_0 = params.problem_size_0.k();
577
+ int problem_size_k_1 = params.problem_size_1.k();
578
+
579
+ if (params.mode == GemmUniversalMode::kGemm) {
580
+
581
+ // Problem size is a function of threadblock index in the K dimension
582
+ problem_size_k_0 = min(
583
+ problem_size_k_0,
584
+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
585
+
586
+ // Problem size is a function of threadblock index in the K dimension
587
+ problem_size_k_1 = min(
588
+ problem_size_k_1,
589
+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
590
+
591
+ offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
592
+ offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
593
+ }
594
+
595
+ else if (params.mode == GemmUniversalMode::kBatched) {
596
+ ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
597
+ ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
598
+ ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
599
+ ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
600
+ ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
601
+ }
602
+
603
+ // Compute initial location in logical coordinates
604
+ cutlass::MatrixCoord tb_offset_A0{
605
+ threadblock_tile_offset.m() * B2bMma::Shape0::kM,
606
+ offset_k_0,
607
+ };
608
+
609
+ cutlass::MatrixCoord tb_offset_B0{
610
+ offset_k_0,
611
+ threadblock_tile_offset.n() * B2bMma::Shape0::kN
612
+ };
613
+
614
+ cutlass::MatrixCoord tb_offset_B1{
615
+ offset_k_1,
616
+ threadblock_tile_offset.n() * B2bMma::Shape1::kN
617
+ };
618
+
619
+ // Compute threadblock-scoped matrix multiply-add
620
+ int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
621
+
622
+ // Compute threadblock-scoped matrix multiply-add
623
+ // int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
624
+
625
+
626
+ // Compute position within threadblock
627
+ int thread_idx = threadIdx.x;
628
+
629
+ // Construct iterators to A and B operands
630
+ typename B2bMma::IteratorA0 iterator_A0(
631
+ params.params_A0,
632
+ ptr_A0,
633
+ {params.problem_size_0.m(), problem_size_k_0},
634
+ thread_idx,
635
+ tb_offset_A0);
636
+
637
+ typename B2bMma::IteratorB0 iterator_B0(
638
+ params.params_B0,
639
+ ptr_B0,
640
+ {problem_size_k_0, params.problem_size_0.n()},
641
+ thread_idx,
642
+ tb_offset_B0);
643
+
644
+ typename B2bMma::IteratorB1 iterator_B1(
645
+ params.params_B1,
646
+ ptr_B1,
647
+ {problem_size_k_1, params.problem_size_1.n()},
648
+ thread_idx,
649
+ tb_offset_B1);
650
+
651
+ // Broadcast the warp_id computed by lane 0 to ensure dependent code
652
+ // is compiled as warp-uniform.
653
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
654
+ int lane_idx = threadIdx.x % 32;
655
+
656
+ // Construct iterators to accumulator scale/bias vector
657
+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
658
+ ptr_Scale0,
659
+ {1, params.problem_size_0.n()},
660
+ thread_idx,
661
+ warp_idx,
662
+ MatrixCoord(
663
+ 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
664
+ )
665
+ );
666
+
667
+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
668
+ ptr_Bias0,
669
+ {1, params.problem_size_0.n()},
670
+ thread_idx,
671
+ warp_idx,
672
+ MatrixCoord(
673
+ 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
674
+ )
675
+ );
676
+
677
+ //
678
+ // Main loop
679
+ //
680
+
681
+ OutputOp0 output_op_0(params.output_op_0);
682
+
683
+ if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle>::value) {
684
+ // Wait for all threads to finish their epilogue phases from the previous tile.
685
+ __syncthreads();
686
+ }
687
+
688
+ // Construct thread-scoped matrix multiply
689
+ B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
690
+
691
+ typename B2bMma::FragmentC0 src_accum;
692
+ typename B2bMma::FragmentC1 accumulators;
693
+
694
+ src_accum.clear();
695
+ accumulators.clear();
696
+
697
+ // Compute threadblock-scoped matrix multiply-add
698
+ b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
699
+ iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
700
+
701
+ //
702
+ // Epilogue
703
+ //
704
+
705
+ OutputOp1 output_op_1(params.output_op_1);
706
+
707
+ //
708
+ // Masked tile iterators constructed from members
709
+ //
710
+
711
+ threadblock_tile_offset =
712
+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
713
+
714
+ //assume identity swizzle
715
+ MatrixCoord threadblock_offset(
716
+ threadblock_tile_offset.m() * B2bMma::Shape1::kM,
717
+ threadblock_tile_offset.n() * B2bMma::Shape1::kN
718
+ );
719
+
720
+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
721
+
722
+ ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
723
+ ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
724
+
725
+ // Construct the semaphore.
726
+ Semaphore semaphore(params.semaphore + block_idx, thread_idx);
727
+
728
+ if (params.mode == GemmUniversalMode::kGemm) {
729
+ // If performing a reduction via split-K, fetch the initial synchronization
730
+
731
+ if (params.grid_tiled_shape.k() > 1) {
732
+ // Fetch the synchronization lock initially but do not block.
733
+ semaphore.fetch();
734
+
735
+ // Indicate which position in a serial reduction the output operator is currently updating
736
+ output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
737
+ }
738
+ }
739
+ else if (params.mode == GemmUniversalMode::kBatched) {
740
+ ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
741
+ ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
742
+ }
743
+
744
+ // Tile iterator loading from source tensor.
745
+ typename Epilogue::OutputTileIterator iterator_C1(
746
+ params.params_C1,
747
+ ptr_C1,
748
+ params.problem_size_1.mn(),
749
+ thread_idx,
750
+ threadblock_offset
751
+ );
752
+
753
+ // Tile iterator writing to destination tensor.
754
+ typename Epilogue::OutputTileIterator iterator_D1(
755
+ params.params_D1,
756
+ ptr_D1,
757
+ params.problem_size_1.mn(),
758
+ thread_idx,
759
+ threadblock_offset
760
+ );
761
+
762
+ Epilogue epilogue(
763
+ shared_storage.epilogue,
764
+ thread_idx,
765
+ warp_idx,
766
+ lane_idx);
767
+
768
+ // Wait on the semaphore - this latency may have been covered by iterator construction
769
+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
770
+
771
+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
772
+ if (threadblock_tile_offset.k()) {
773
+ iterator_C1 = iterator_D1;
774
+ }
775
+
776
+ semaphore.wait(threadblock_tile_offset.k());
777
+
778
+ __threadfence();
779
+ }
780
+
781
+ // Execute the epilogue operator to update the destination tensor.
782
+ epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
783
+
784
+ //
785
+ // Release the semaphore
786
+ //
787
+
788
+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
789
+
790
+ int lock = 0;
791
+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
792
+
793
+ // The final threadblock resets the semaphore for subsequent grids.
794
+ lock = 0;
795
+ }
796
+ else {
797
+ // Otherwise, the semaphore is incremented
798
+ lock = threadblock_tile_offset.k() + 1;
799
+ }
800
+
801
+ __threadfence();
802
+ semaphore.release(lock);
803
+ }
804
+ }
805
+ };
806
+
807
+ /////////////////////////////////////////////////////////////////////////////////////////////////
808
+
809
+ } // namespace kernel
810
+ } // namespace gemm
811
+ } // namespace cutlass
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Scheduler for grouped B2b GEMMs
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/gemm/gemm.h"
40
+ #include "cutlass/matrix_coord.h"
41
+ #include "cutlass/gemm/kernel/grouped_problem_visitor.h"
42
+ #include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ namespace cutlass {
47
+ namespace gemm {
48
+ namespace kernel {
49
+
50
+ /////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ /// Visitor class to abstract away the algorithm for iterating over tiles
53
+ template <typename ThreadblockShape,
54
+ GroupScheduleMode GroupScheduleMode_,
55
+ int PrefetchTileCount,
56
+ int ThreadCount,
57
+ bool Transposed = false>
58
+ struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor<
59
+ detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
60
+ ThreadblockShape,
61
+ GroupScheduleMode_,
62
+ PrefetchTileCount,
63
+ ThreadCount> {
64
+
65
+ using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
66
+ using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
67
+ using BaseParams = typename Base::Params;
68
+ using SharedStorage = typename Base::SharedStorage;
69
+ static bool const kTransposed = Transposed;
70
+
71
+ cutlass::gemm::GemmCoord const *problem_sizes0;
72
+ cutlass::gemm::GemmCoord const *problem_sizes1;
73
+
74
+ struct Params {
75
+ cutlass::gemm::GemmCoord const *problem_sizes0;
76
+ cutlass::gemm::GemmCoord const *problem_sizes1;
77
+ int32_t problem_count;
78
+ void const *workspace;
79
+ int32_t tile_count;
80
+
81
+ //
82
+ // Methods
83
+ //
84
+
85
+ /// Ctor
86
+ CUTLASS_HOST_DEVICE
87
+ Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
88
+ problem_count(0), workspace(nullptr), tile_count(0) { }
89
+
90
+ /// Ctor
91
+ CUTLASS_HOST_DEVICE
92
+ Params(
93
+ cutlass::gemm::GemmCoord const *problem_sizes0,
94
+ cutlass::gemm::GemmCoord const *problem_sizes1,
95
+ int32_t problem_count,
96
+ void const *workspace = nullptr,
97
+ int32_t tile_count = 0
98
+ ):
99
+ problem_sizes0(problem_sizes0),
100
+ problem_sizes1(problem_sizes1),
101
+ problem_count(problem_count),
102
+ workspace(workspace),
103
+ tile_count(tile_count)
104
+ {}
105
+
106
+ /// Convert the B2b-GEMM-specific parameters to those used by the base class
107
+ CUTLASS_HOST_DEVICE
108
+ BaseParams to_base() const {
109
+ return BaseParams(// Set problem_sizes as problem_sizes0 because these determine
110
+ // shape of the grid used in the non-grouped B2b GEMM
111
+ problem_sizes0,
112
+ problem_count,
113
+ workspace,
114
+ tile_count);
115
+ }
116
+
117
+ };
118
+
119
+ //
120
+ // Methods
121
+ //
122
+ CUTLASS_DEVICE
123
+ B2bGemmGroupedProblemVisitor(
124
+ Params const &params_,
125
+ SharedStorage &shared_storage_,
126
+ int32_t block_idx
127
+ ): Base (
128
+ params_.to_base(),
129
+ shared_storage_, block_idx),
130
+ problem_sizes0(params_.problem_sizes0),
131
+ problem_sizes1(params_.problem_sizes1)
132
+ {}
133
+
134
+ /// Returns the problem size 0 for the current problem
135
+ CUTLASS_HOST_DEVICE
136
+ cutlass::gemm::GemmCoord problem_size0() const {
137
+ GemmCoord problem = problem_sizes0[this->problem_idx];
138
+ ProblemSizeHelper::possibly_transpose_problem(problem);
139
+ return problem;
140
+ }
141
+
142
+ /// Returns the problem size 1 for the current problem
143
+ CUTLASS_HOST_DEVICE
144
+ cutlass::gemm::GemmCoord problem_size1() const {
145
+ GemmCoord problem = problem_sizes1[this->problem_idx];
146
+ ProblemSizeHelper::possibly_transpose_problem(problem);
147
+ return problem;
148
+ }
149
+ };
150
+
151
+ /////////////////////////////////////////////////////////////////////////////////////////////////
152
+
153
+ } // namespace kernel
154
+ } // namespace gemm
155
+ } // namespace cutlass
156
+
157
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Template for a pipelined Implicit GEMM kernel.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+
39
+ #include "cutlass/aligned_buffer.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/numeric_types.h"
42
+ #include "cutlass/matrix_shape.h"
43
+ #include "cutlass/semaphore.h"
44
+ #include "cutlass/tensor_ref.h"
45
+ #include "cutlass/layout/tensor.h"
46
+ #include "cutlass/gemm/gemm.h"
47
+ #include "cutlass/conv/convolution.h"
48
+ #include "cutlass/conv/conv2d_problem_size.h"
49
+ #include "cutlass/conv/conv3d_problem_size.h"
50
+ #include "cutlass/epilogue/threadblock/output_iterator_parameter.h"
51
+
52
+ /////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ namespace cutlass {
55
+ namespace conv {
56
+ namespace kernel {
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ template <
61
+ typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
62
+ typename Epilogue_, ///! Epilogue
63
+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function
64
+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad)
65
+ typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem
66
+ >
67
+ struct B2bImplicitGemmConvolution {
68
+
69
+ using B2bMma = B2bMma_;
70
+ using Epilogue = Epilogue_;
71
+ using EpilogueOutputOp0 = typename B2bMma::OutputOp;
72
+ using EpilogueOutputOp1 = typename Epilogue::OutputOp;
73
+ using ThreadblockSwizzle = ThreadblockSwizzle_;
74
+ static Operator const kConvolutionalOperator = ConvOperator;
75
+
76
+ using ElementA = typename B2bMma::IteratorA0::Element;
77
+ using LayoutA = typename B2bMma::IteratorA0::Layout;
78
+ using ElementB = typename B2bMma::IteratorB0::Element;
79
+ using LayoutB = typename B2bMma::IteratorB0::Layout;
80
+ using ElementC = typename EpilogueOutputOp1::ElementOutput;
81
+
82
+ /// Set output tensor C layout
83
+ using LayoutC = LayoutA;
84
+
85
+ using ElementAccumulator = typename EpilogueOutputOp0::ElementAccumulator;
86
+ using ElementCompute = typename EpilogueOutputOp0::ElementCompute;
87
+
88
+ /// Scale and Bias
89
+ using ElementScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Element;
90
+ using LayoutScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Layout;
91
+
92
+ using WarpMmaOperator0 = typename B2bMma::Policy0::Operator;
93
+ using WarpMmaOperator1 = typename B2bMma::Policy1::Operator;
94
+
95
+ using ArchMmaOperator = typename WarpMmaOperator0::ArchMmaOperator;
96
+ using MathOperator = typename ArchMmaOperator::Operator;
97
+
98
+ using OperatorClass = typename WarpMmaOperator0::OperatorClass;
99
+ using ArchTag = typename WarpMmaOperator0::ArchTag;
100
+
101
+ using ThreadblockShape0 = typename B2bMma::Shape0;
102
+ using ThreadblockShape1 = typename B2bMma::Shape1;
103
+ using WarpShape0 = typename WarpMmaOperator0::Shape;
104
+ using WarpShape1 = typename WarpMmaOperator1::Shape;
105
+ using InstructionShape = typename ArchMmaOperator::Shape;
106
+
107
+ static int const kStages = B2bMma::kStages;
108
+ static IteratorAlgorithm const kIteratorAlgorithm = B2bMma::IteratorA0::kIteratorAlgorithm;
109
+
110
+ /// Warp count (concept: GemmShape)
111
+ using WarpCount0 = typename B2bMma::WarpCount0;
112
+ static int const kThreadCount = 32 * WarpCount0::kCount;
113
+
114
+ using TensorRefA0 = typename B2bMma::IteratorA0::TensorRef;
115
+ using TensorRefB0 = typename B2bMma::IteratorB0::TensorRef;
116
+ using TensorRefScaleBias0 = typename B2bMma::IteratorAccumulatorScaleBias::TensorRef;
117
+ using TensorRefB1 = typename B2bMma::IteratorB1::TensorRef;
118
+ using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;
119
+
120
+ /// Check iterator A and B convolution dimension are the same and
121
+ // set device::B2bImplicitGemmConvolution::kConvDim
122
+ static_assert(B2bMma::IteratorA0::kConvDim == B2bMma::IteratorB0::kConvDim,
123
+ "Convolution on different dimensions is not supported");
124
+ static int const kConvDim = B2bMma::IteratorA0::kConvDim;
125
+
126
+ /// Conv dimension and problem size structure (Conv2d or Conv3d)
127
+ using ConvProblemSize = ConvProblemSize_;
128
+
129
+ /// Wgrad C stride idx for implicit gemm algorithm
130
+ // Conv2d row-major matrix C (KxRSC)
131
+ // Conv3d row-major matrix C (KxTRSC)
132
+ static int const kWgradCStrideIdx =
133
+ cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
134
+
135
+ /// This chooses the appropriate stride element of the C tensor.
136
+ static int const kTensorCStrideIdx =
137
+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
138
+
139
+ //
140
+ //
141
+ //
142
+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
143
+ LayoutC,
144
+ typename Epilogue::OutputTileIterator::Layout,
145
+ TensorRefC,
146
+ ConvOperator,
147
+ ConvProblemSize
148
+ >;
149
+
150
+ /// Argument structure
151
+ struct Arguments {
152
+
153
+ //
154
+ // Data members
155
+ //
156
+
157
+ ConvProblemSize problem_size_0;
158
+ ConvProblemSize problem_size_1;
159
+ TensorRefA0 ref_A0;
160
+ TensorRefB0 ref_B0;
161
+ TensorRefC ref_C0;
162
+ TensorRefScaleBias0 ref_Scale0;
163
+ TensorRefScaleBias0 ref_Bias0;
164
+ TensorRefB1 ref_B1;
165
+ TensorRefC ref_C1;
166
+ TensorRefC ref_D1;
167
+ typename EpilogueOutputOp0::Params output_op_0;
168
+ typename EpilogueOutputOp1::Params output_op_1;
169
+ SplitKMode split_k_mode;
170
+
171
+ //
172
+ // Methods
173
+ //
174
+
175
+ /// Default ctor
176
+ CUTLASS_HOST_DEVICE
177
+ Arguments() { }
178
+
179
+ CUTLASS_HOST_DEVICE
180
+ Arguments(
181
+ ConvProblemSize const & problem_size_0,
182
+ ConvProblemSize const & problem_size_1
183
+ ):
184
+ problem_size_0(problem_size_0),
185
+ problem_size_1(problem_size_1) { }
186
+
187
+ CUTLASS_HOST_DEVICE
188
+ Arguments(
189
+ ConvProblemSize const & problem_size_0,
190
+ ConvProblemSize const & problem_size_1,
191
+ TensorRefA0 const & ref_A0,
192
+ TensorRefB0 const & ref_B0,
193
+ TensorRefC const & ref_C0,
194
+ TensorRefScaleBias0 const & ref_Scale0,
195
+ TensorRefScaleBias0 const & ref_Bias0,
196
+ TensorRefB1 const & ref_B1,
197
+ TensorRefC const & ref_C1,
198
+ TensorRefC const & ref_D1,
199
+ typename EpilogueOutputOp0::Params const & output_op_0,
200
+ typename EpilogueOutputOp1::Params const & output_op_1,
201
+ SplitKMode const & split_k_mode = SplitKMode::kSerial
202
+ ):
203
+ problem_size_0(problem_size_0),
204
+ problem_size_1(problem_size_1),
205
+ ref_A0(ref_A0),
206
+ ref_B0(ref_B0),
207
+ ref_C0(ref_C0),
208
+ ref_Scale0(ref_Scale0),
209
+ ref_Bias0(ref_Bias0),
210
+ ref_B1(ref_B1),
211
+ ref_C1(ref_C1),
212
+ ref_D1(ref_D1),
213
+ output_op_0(output_op_0),
214
+ output_op_1(output_op_1),
215
+ split_k_mode(split_k_mode)
216
+ {
217
+
218
+ }
219
+
220
+ };
221
+
222
+ /// Parameters structure
223
+ struct Params {
224
+ ConvProblemSize problem_size_0;
225
+ ConvProblemSize problem_size_1;
226
+ cutlass::gemm::GemmCoord grid_tiled_shape;
227
+ gemm::GemmCoord implicit_gemm_problem_size_0;
228
+ gemm::GemmCoord implicit_gemm_problem_size_1;
229
+ int swizzle_log_tile;
230
+ int gemm_k_iterations_0;
231
+ int gemm_k_iterations_1;
232
+ typename B2bMma::IteratorA0::Params iterator_A0;
233
+ typename B2bMma::IteratorA0::Element const *ptr_A0;
234
+ typename B2bMma::IteratorB0::Params iterator_B0;
235
+ typename B2bMma::IteratorB0::Element const *ptr_B0;
236
+ typename Epilogue::OutputTileIterator::Params iterator_C0;
237
+ typename Epilogue::OutputTileIterator::Element *ptr_C0;
238
+ typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Scale0;
239
+ typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Bias0;
240
+ typename B2bMma::IteratorB1::Params iterator_B1;
241
+ typename B2bMma::IteratorB1::Element const *ptr_B1;
242
+ typename Epilogue::OutputTileIterator::Params iterator_C1;
243
+ typename Epilogue::OutputTileIterator::Element *ptr_C1;
244
+ typename Epilogue::OutputTileIterator::Params iterator_D1;
245
+ typename Epilogue::OutputTileIterator::Element *ptr_D1;
246
+ typename EpilogueOutputOp0::Params output_op_0;
247
+ typename EpilogueOutputOp1::Params output_op_1;
248
+ int *semaphore;
249
+ SplitKMode split_k_mode;
250
+
251
+ //
252
+ // Methods
253
+ //
254
+
255
+ CUTLASS_HOST_DEVICE
256
+ Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { }
257
+
258
+ ///
259
+ CUTLASS_HOST_DEVICE
260
+ Params(
261
+ Arguments const &args,
262
+ int *semaphore = nullptr
263
+ ):
264
+ problem_size_0(args.problem_size_0),
265
+ problem_size_1(args.problem_size_1),
266
+ implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)),
267
+ implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)),
268
+ iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())),
269
+ ptr_A0(args.ref_A0.data()),
270
+ iterator_B0(args.problem_size_0, args.ref_B0.layout()),
271
+ ptr_B0(args.ref_B0.data()),
272
+ iterator_C0(ConvOutputIteratorParameter::layout(args.ref_C0)),
273
+ ptr_C0(args.ref_C0.data()),
274
+ ptr_Scale0(args.ref_Scale0.data()),
275
+ ptr_Bias0(args.ref_Bias0.data()),
276
+ iterator_B1(args.problem_size_1, args.ref_B1.layout()),
277
+ ptr_B1(args.ref_B1.data()),
278
+ iterator_C1(ConvOutputIteratorParameter::layout(args.ref_C1)),
279
+ ptr_C1(args.ref_C1.data()),
280
+ iterator_D1(ConvOutputIteratorParameter::layout(args.ref_D1)),
281
+ ptr_D1(args.ref_D1.data()),
282
+ output_op_0(args.output_op_0),
283
+ output_op_1(args.output_op_1),
284
+ semaphore(semaphore),
285
+ split_k_mode(args.split_k_mode)
286
+ {
287
+ gemm_k_iterations_0 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape0::kK, args.problem_size_0);
288
+ gemm_k_iterations_1 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape1::kK, args.problem_size_1);
289
+
290
+ ThreadblockSwizzle threadblock_swizzle;
291
+
292
+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
293
+ implicit_gemm_problem_size_0,
294
+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
295
+ args.problem_size_0.split_k_slices);
296
+
297
+ swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
298
+ }
299
+ };
300
+
301
+ /// Shared memory storage structure
302
+ union SharedStorage {
303
+ typename B2bMma::B2bMmaSharedStorage main_loop;
304
+ typename Epilogue::SharedStorage epilogue;
305
+ };
306
+
307
+ //
308
+ // Methods
309
+ //
310
+
311
+ CUTLASS_HOST_DEVICE
312
+ B2bImplicitGemmConvolution() { }
313
+
314
+ /// Executes one ImplicitGEMM
315
+ CUTLASS_DEVICE
316
+ void operator()(Params const &params, SharedStorage &shared_storage) {
317
+
318
+ // Compute threadblock location
319
+ ThreadblockSwizzle threadblock_swizzle;
320
+
321
+ cutlass::gemm::GemmCoord threadblock_tile_idx =
322
+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
323
+
324
+ // Early exit if CTA is out of range
325
+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
326
+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {
327
+
328
+ return;
329
+ }
330
+
331
+ // Compute position within threadblock
332
+ int thread_idx = threadIdx.x;
333
+
334
+ // Construct iterators to A and B operands
335
+ typename B2bMma::IteratorA0 iterator_A0(
336
+ params.iterator_A0,
337
+ params.problem_size_0,
338
+ params.ptr_A0,
339
+ thread_idx,
340
+ MatrixCoord(
341
+ threadblock_tile_idx.m() * B2bMma::Shape0::kM,
342
+ threadblock_tile_idx.k() * B2bMma::Shape0::kK
343
+ )
344
+ );
345
+
346
+ typename B2bMma::IteratorB0 iterator_B0(
347
+ params.iterator_B0,
348
+ params.problem_size_0,
349
+ params.ptr_B0,
350
+ thread_idx,
351
+ MatrixCoord(
352
+ threadblock_tile_idx.k() * B2bMma::Shape0::kK,
353
+ threadblock_tile_idx.n() * B2bMma::Shape0::kN
354
+ )
355
+ );
356
+
357
+ typename B2bMma::IteratorB1 iterator_B1(
358
+ params.iterator_B1,
359
+ params.problem_size_1,
360
+ params.ptr_B1,
361
+ thread_idx,
362
+ MatrixCoord(
363
+ threadblock_tile_idx.k() * B2bMma::Shape1::kK,
364
+ threadblock_tile_idx.n() * B2bMma::Shape1::kN
365
+ )
366
+ );
367
+
368
+
369
+ // Broadcast the warp_id computed by lane 0 to ensure dependent code
370
+ // is compiled as warp-uniform.
371
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
372
+ int lane_idx = threadIdx.x % 32;
373
+
374
+ // Construct iterators to accumulator scale/bias vector
375
+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
376
+ params.ptr_Scale0,
377
+ {1, params.problem_size_0.K},
378
+ thread_idx,
379
+ warp_idx,
380
+ MatrixCoord(
381
+ 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN
382
+ )
383
+ );
384
+
385
+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
386
+ params.ptr_Bias0,
387
+ {1, params.problem_size_0.K},
388
+ thread_idx,
389
+ warp_idx,
390
+ MatrixCoord(
391
+ 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN
392
+ )
393
+ );
394
+
395
+
396
+ //
397
+ // Main loop
398
+ //
399
+
400
+ EpilogueOutputOp0 output_op_0(params.output_op_0);
401
+
402
+ // Construct thread-scoped matrix multiply
403
+ B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
404
+
405
+ typename B2bMma::FragmentC0 src_accum;
406
+ typename B2bMma::FragmentC1 accumulators;
407
+
408
+ src_accum.clear();
409
+ accumulators.clear();
410
+
411
+ // Compute threadblock-scoped matrix multiply-add
412
+ b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
413
+ iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
414
+
415
+ //
416
+ // Epilogue
417
+ //
418
+
419
+ EpilogueOutputOp1 output_op_1(params.output_op_1);
420
+
421
+ // Construct the semaphore.
422
+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m();
423
+
424
+ Semaphore semaphore(params.semaphore + block_idx, thread_idx);
425
+
426
+ // Compute logical position within grid
427
+ threadblock_tile_idx =
428
+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
429
+
430
+ // If performing a reduction via split-K, fetch the initial synchronization
431
+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
432
+
433
+ // Fetch the synchronization lock initially but do not block.
434
+ semaphore.fetch();
435
+
436
+ // Indicate which position in a serial reduction the output operator is currently updating
437
+ output_op_1.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k());
438
+ }
439
+
440
+ MatrixCoord threadblock_offset(
441
+ threadblock_tile_idx.m() * B2bMma::Shape1::kM,
442
+ threadblock_tile_idx.n() * B2bMma::Shape1::kN
443
+ );
444
+
445
+ // Tile iterator writing to destination tensor
446
+ typename Epilogue::OutputTileIterator iterator_D1(
447
+ params.iterator_D1,
448
+ params.ptr_D1,
449
+ ConvOutputIteratorParameter::extent(params.problem_size_1),
450
+ thread_idx,
451
+ threadblock_offset
452
+ );
453
+
454
+ // Tile iterator reading from source accumulator tensor
455
+ typename Epilogue::OutputTileIterator iterator_C1(
456
+ params.iterator_C1,
457
+ params.ptr_C1,
458
+ ConvOutputIteratorParameter::extent(params.problem_size_1),
459
+ thread_idx,
460
+ threadblock_offset
461
+ );
462
+
463
+
464
+ // Construct the epilogue
465
+ Epilogue epilogue(
466
+ shared_storage.epilogue,
467
+ thread_idx,
468
+ warp_idx,
469
+ lane_idx);
470
+
471
+ // Wait on the semaphore - this latency may have been covered by iterator construction
472
+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
473
+
474
+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
475
+ if (threadblock_tile_idx.k()) {
476
+ iterator_C1 = iterator_D1;
477
+ }
478
+
479
+ semaphore.wait(threadblock_tile_idx.k());
480
+
481
+ __threadfence();
482
+ }
483
+ // Each split-k-slice writes to a unique tensor location
484
+ else if (params.split_k_mode == SplitKMode::kParallel) {
485
+ iterator_D1.add_pointer_offset(threadblock_tile_idx.k() *
486
+ cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size_1));
487
+ }
488
+
489
+ // Run efficient epilogue
490
+ epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
491
+
492
+ //
493
+ // Release the semaphore
494
+ //
495
+
496
+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
497
+
498
+ int lock = 0;
499
+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) {
500
+
501
+ // The final threadblock resets the semaphore for subsequent grids.
502
+ lock = 0;
503
+ }
504
+ else {
505
+ // Otherwise, the semaphore is incremented
506
+ lock = threadblock_tile_idx.k() + 1;
507
+ }
508
+
509
+ semaphore.release(lock);
510
+ }
511
+ }
512
+ };
513
+
514
+ /////////////////////////////////////////////////////////////////////////////////////////////////
515
+
516
+ } // namespace kernel
517
+ } // namespace conv
518
+ } // namespace cutlass
519
+
520
+ /////////////////////////////////////////////////////////////////////////////////////////////////
521
+
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
35
+ matrix multiply-add with the appropriate threadblock-scoped epilogue.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/conv/kernel/default_conv2d.h"
42
+
43
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
44
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
45
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
46
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
47
+
48
+ #include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
49
+ #include "cutlass/transform/threadblock/vector_iterator.h"
50
+ #include "cutlass/transform/warp/vector_fragment_iterator.h"
51
+
52
+ #include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
53
+
54
+ #include "kernel/b2b_implicit_gemm_convolution.h"
55
+ #include "threadblock/b2b_implicit_gemm_pipelined.h"
56
+ #include "threadblock/b2b_implicit_gemm_multistage.h"
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace cutlass {
61
+ namespace conv {
62
+ namespace kernel {
63
+
64
+ /////////////////////////////////////////////////////////////////////////////////////////////////
65
+ /// Defines a kernel for Conv2dFprop
66
+ template <
67
+ typename ElementA,
68
+ typename LayoutA,
69
+ typename ElementB,
70
+ typename LayoutB,
71
+ typename ElementC,
72
+ typename LayoutC,
73
+ typename ElementAccumulator,
74
+ typename OperatorClass,
75
+ typename ArchTag,
76
+ typename ThreadblockShape0,
77
+ typename ThreadblockShape1,
78
+ typename WarpShape0,
79
+ typename WarpShape1,
80
+ typename InstructionShape,
81
+ typename EpilogueOutputOp0,
82
+ typename EpilogueOutputOp1,
83
+ typename ThreadblockSwizzle,
84
+ int Stages,
85
+ typename MathOperatorTag,
86
+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic,
87
+ bool SmemAccumulator = false
88
+ > struct DefaultB2bConv2dFprop;
89
+
90
+ } // namespace kernel
91
+ } // namespace conv
92
+ } // namespace cutlass
93
+
94
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
35
+ matrix multiply-add with the appropriate threadblock-scoped epilogue.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/conv/kernel/default_conv2d.h"
42
+
43
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
44
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
45
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
46
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
47
+
48
+ #include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
49
+ #include "cutlass/transform/threadblock/vector_iterator.h"
50
+ #include "cutlass/transform/warp/vector_fragment_iterator.h"
51
+
52
+ #include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
53
+
54
+ #include "kernel/default_b2b_conv2d_fprop.h"
55
+ #include "kernel/b2b_implicit_gemm_convolution.h"
56
+ #include "threadblock/b2b_implicit_gemm_pipelined.h"
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace cutlass {
61
+ namespace conv {
62
+ namespace kernel {
63
+
64
+ /////////////////////////////////////////////////////////////////////////////////////////////////
65
+ // OpClassTensorOp convolutions
66
+ /////////////////////////////////////////////////////////////////////////////////////////////////
67
+
68
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
69
+ /// and 2 stage pipeline.
70
+ template <
71
+ typename ElementA,
72
+ typename LayoutA,
73
+ typename ElementB,
74
+ typename LayoutB,
75
+ typename ElementC,
76
+ typename LayoutC,
77
+ typename ElementAccumulator,
78
+ typename ArchTag,
79
+ typename ThreadblockShape0,
80
+ typename ThreadblockShape1,
81
+ typename WarpShape0,
82
+ typename WarpShape1,
83
+ typename InstructionShape,
84
+ typename EpilogueOutputOp0,
85
+ typename EpilogueOutputOp1,
86
+ typename ThreadblockSwizzle,
87
+ typename MathOperatorTag
88
+ >
89
+ struct DefaultB2bConv2dFprop <
90
+ ElementA,
91
+ LayoutA,
92
+ ElementB,
93
+ LayoutB,
94
+ ElementC,
95
+ LayoutC,
96
+ ElementAccumulator,
97
+ arch::OpClassTensorOp,
98
+ ArchTag,
99
+ ThreadblockShape0,
100
+ ThreadblockShape1,
101
+ WarpShape0,
102
+ WarpShape1,
103
+ InstructionShape,
104
+ EpilogueOutputOp0,
105
+ EpilogueOutputOp1,
106
+ ThreadblockSwizzle,
107
+ 2,
108
+ MathOperatorTag,
109
+ IteratorAlgorithm::kAnalytic
110
+ > {
111
+
112
+ // Define the core components from GEMM
113
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
114
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
115
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
116
+ 2, MathOperatorTag>;
117
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
118
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
119
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
120
+ 2, MathOperatorTag>;
121
+
122
+ // Define iterators over tiles from the A operand
123
+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
124
+ using IteratorA0 =
125
+ cutlass::conv::threadblock::TileIterator<
126
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
127
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
128
+ ElementA, LayoutA,
129
+ ThreadMapA0
130
+ >
131
+ >;
132
+
133
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
134
+
135
+ // Define iterators over tiles from the B operand
136
+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
137
+ using IteratorB0 =
138
+ cutlass::conv::threadblock::TileIterator<
139
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
140
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
141
+ ElementB, LayoutB,
142
+ ThreadMapB0
143
+ >
144
+ >;
145
+
146
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
147
+
148
+ // Use fragment iterator for A operand
149
+ using AccumulatorLayout = cutlass::layout::ColumnMajor;
150
+ using FragmentIteratorA1 =
151
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
152
+ cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
153
+ cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
154
+ MmaCore1::Shape::kK, //kBlocksColumn
155
+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
156
+
157
+ /// Define iterators over tiles from scale/bias vectors
158
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
159
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
160
+ static int const kElementsPerAccess = 2;
161
+ using IteratorAccumulatorScaleBias =
162
+ cutlass::transform::threadblock::VectorIterator<
163
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
164
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
165
+ cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
166
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
167
+ >;
168
+
169
+ // Warp-level iterators to load scale and bias vectors
170
+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
171
+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
172
+ LayoutScaleBias, InstructionShape, kElementsPerAccess>;
173
+
174
+ // Define iterators over tiles from the B operand
175
+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
176
+ using IteratorB1 =
177
+ cutlass::conv::threadblock::TileIterator<
178
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
179
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
180
+ ElementB, LayoutB,
181
+ ThreadMapB1
182
+ >
183
+ >;
184
+
185
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
186
+
187
+ // Warp-level GEMM components
188
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
189
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
190
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
191
+
192
+ // Define the Mma
193
+ using B2bMma = threadblock::B2bImplicitGemmPipelined<
194
+ ThreadblockShape0,
195
+ IteratorA0,
196
+ SmemIteratorA0,
197
+ IteratorB0,
198
+ SmemIteratorB0,
199
+ ThreadblockShape1,
200
+ FragmentIteratorA1,
201
+ IteratorAccumulatorScaleBias,
202
+ FragmentIteratorA1ScaleBias,
203
+ IteratorB1,
204
+ SmemIteratorB1,
205
+ ElementC,
206
+ LayoutC,
207
+ EpilogueOutputOp0,
208
+ MmaPolicy0,
209
+ MmaPolicy1
210
+ >;
211
+
212
+ // Define the epilogue
213
+ using Epilogue = typename detail::DefaultConvEpilogue<
214
+ ArchTag,
215
+ ThreadblockShape1,
216
+ WarpMmaTensorOp1,
217
+ 1,
218
+ EpilogueOutputOp1
219
+ >::Epilogue;
220
+
221
+ // Define the kernel
222
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
223
+ B2bMma,
224
+ Epilogue,
225
+ ThreadblockSwizzle,
226
+ conv::Operator::kFprop
227
+ >;
228
+ };
229
+
230
+ /////////////////////////////////////////////////////////////////////////////////////////////////
231
+
232
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
233
+ /// pipeline with interleaved layout.
234
+ template <
235
+ typename ElementA,
236
+ typename ElementB,
237
+ typename ElementC,
238
+ typename LayoutC,
239
+ typename ElementAccumulator,
240
+ typename ArchTag,
241
+ typename ThreadblockShape0,
242
+ typename ThreadblockShape1,
243
+ typename WarpShape0,
244
+ typename WarpShape1,
245
+ typename InstructionShape,
246
+ typename EpilogueOutputOp0,
247
+ typename EpilogueOutputOp1,
248
+ typename ThreadblockSwizzle,
249
+ typename MathOperatorTag,
250
+ int InterleavedK
251
+ >
252
+ struct DefaultB2bConv2dFprop <
253
+ ElementA,
254
+ layout::TensorNCxHWx<InterleavedK>,
255
+ ElementB,
256
+ layout::TensorCxRSKx<InterleavedK>,
257
+ ElementC,
258
+ LayoutC,
259
+ ElementAccumulator,
260
+ arch::OpClassTensorOp,
261
+ ArchTag,
262
+ ThreadblockShape0,
263
+ ThreadblockShape1,
264
+ WarpShape0,
265
+ WarpShape1,
266
+ InstructionShape,
267
+ EpilogueOutputOp0,
268
+ EpilogueOutputOp1,
269
+ ThreadblockSwizzle,
270
+ 2,
271
+ MathOperatorTag,
272
+ IteratorAlgorithm::kAnalytic,
273
+ false
274
+ > {
275
+
276
+ // Define the core components from GEMM
277
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
278
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
279
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
280
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
281
+ 2, MathOperatorTag, true>;
282
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
283
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
284
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
285
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
286
+ 2, MathOperatorTag, true>;
287
+
288
+ // Define iterators over tiles from the A operand
289
+ // Note GEMM shared memory threadmap is used here because conv global memory
290
+ // layout needs to be mapped to fprop which is similar to the crosswise
291
+ // layout which is used by the interleaved GEMM shared memory threadmap.
292
+ // The Interleaved GEMM global memory layout is similar to the congruous
293
+ // layout.
294
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
295
+ using IteratorA0 =
296
+ cutlass::conv::threadblock::TileIterator<
297
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
298
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
299
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
300
+ ThreadMapA0
301
+ >
302
+ >;
303
+
304
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
305
+
306
+ // Define iterators over tiles from the B operand
307
+ // Note GEMM shared memory threadmap is used here because conv global memory
308
+ // layout needs to be mapped to fprop which is similar to the crosswise
309
+ // layout which is used by the interleaved GEMM shared memory threadmap.
310
+ // The Interleaved GEMM global memory layout is similar to the congruous
311
+ // layout.
312
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
313
+ using IteratorB0 =
314
+ cutlass::conv::threadblock::TileIterator<
315
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
316
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
317
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
318
+ ThreadMapB0
319
+ >
320
+ >;
321
+
322
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
323
+
324
+ // Use fragment iterator for A operand
325
+ using AccumulatorLayout = cutlass::layout::RowMajor;
326
+ using FragmentIteratorA1 =
327
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
328
+ cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
329
+ cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
330
+ MmaCore1::Shape::kK, //kBlocksColumn
331
+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
332
+
333
+ /// Define iterators over tiles from scale/bias vectors
334
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
335
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
336
+ static int const kElementsPerAccess = 4;
337
+ using IteratorAccumulatorScaleBias =
338
+ cutlass::transform::threadblock::VectorIterator<
339
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
340
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
341
+ cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
342
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
343
+ >;
344
+
345
+ // Warp-level iterators to load scale and bias vectors
346
+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
347
+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
348
+ LayoutScaleBias, InstructionShape, kElementsPerAccess>;
349
+
350
+ // Define iterators over tiles from the B operand
351
+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
352
+ using IteratorB1 =
353
+ cutlass::conv::threadblock::TileIterator<
354
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
355
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
356
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
357
+ ThreadMapB1
358
+ >
359
+ >;
360
+
361
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
362
+
363
+ // Warp-level GEMM components
364
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
365
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
366
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
367
+
368
+ // Define the Mma
369
+ using B2bMma = threadblock::B2bImplicitGemmPipelined<
370
+ ThreadblockShape0,
371
+ IteratorA0,
372
+ SmemIteratorA0,
373
+ IteratorB0,
374
+ SmemIteratorB0,
375
+ ThreadblockShape1,
376
+ FragmentIteratorA1,
377
+ IteratorAccumulatorScaleBias,
378
+ FragmentIteratorA1ScaleBias,
379
+ IteratorB1,
380
+ SmemIteratorB1,
381
+ ElementC,
382
+ LayoutC,
383
+ EpilogueOutputOp0,
384
+ MmaPolicy0,
385
+ MmaPolicy1
386
+ >;
387
+
388
+ // Define the epilogue
389
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
390
+ ThreadblockShape1,
391
+ WarpMmaTensorOp1,
392
+ 1,
393
+ EpilogueOutputOp1,
394
+ EpilogueOutputOp1::kCount,
395
+ InterleavedK
396
+ >::Epilogue;
397
+
398
+ // Define the kernel
399
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
400
+ B2bMma,
401
+ Epilogue,
402
+ ThreadblockSwizzle,
403
+ conv::Operator::kFprop
404
+ >;
405
+ };
406
+
407
+ /////////////////////////////////////////////////////////////////////////////////////////////////
408
+
409
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
410
+ /// and 2 stage pipeline.
411
+ template <
412
+ typename ElementA,
413
+ typename LayoutA,
414
+ typename ElementB,
415
+ typename LayoutB,
416
+ typename ElementC,
417
+ typename LayoutC,
418
+ typename ElementAccumulator,
419
+ typename ArchTag,
420
+ typename ThreadblockShape0,
421
+ typename ThreadblockShape1,
422
+ typename WarpShape0,
423
+ typename WarpShape1,
424
+ typename InstructionShape,
425
+ typename EpilogueOutputOp0,
426
+ typename EpilogueOutputOp1,
427
+ typename ThreadblockSwizzle,
428
+ typename MathOperatorTag
429
+ >
430
+ struct DefaultB2bConv2dFprop <
431
+ ElementA,
432
+ LayoutA,
433
+ ElementB,
434
+ LayoutB,
435
+ ElementC,
436
+ LayoutC,
437
+ ElementAccumulator,
438
+ arch::OpClassTensorOp,
439
+ ArchTag,
440
+ ThreadblockShape0,
441
+ ThreadblockShape1,
442
+ WarpShape0,
443
+ WarpShape1,
444
+ InstructionShape,
445
+ EpilogueOutputOp0,
446
+ EpilogueOutputOp1,
447
+ ThreadblockSwizzle,
448
+ 2,
449
+ MathOperatorTag,
450
+ IteratorAlgorithm::kOptimized
451
+ > {
452
+
453
+ // Define the core components from GEMM
454
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
455
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
456
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
457
+ 2, MathOperatorTag>;
458
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
459
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
460
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
461
+ 2, MathOperatorTag>;
462
+
463
+ // Define iterators over tiles from the A operand
464
+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
465
+ using IteratorA0 =
466
+ cutlass::conv::threadblock::TileIterator<
467
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
468
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
469
+ ElementA, LayoutA,
470
+ ThreadMapA0
471
+ >
472
+ >;
473
+
474
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
475
+
476
+ // Define iterators over tiles from the B operand
477
+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
478
+ using IteratorB0 =
479
+ cutlass::conv::threadblock::TileIterator<
480
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
481
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
482
+ ElementB, LayoutB,
483
+ ThreadMapB0
484
+ >
485
+ >;
486
+
487
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
488
+
489
+ // Use fragment iterator for A operand
490
+ using AccumulatorLayout = cutlass::layout::ColumnMajor;
491
+ using FragmentIteratorA1 =
492
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
493
+ cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
494
+ cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
495
+ MmaCore1::Shape::kK, //kBlocksColumn
496
+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
497
+
498
+ /// Define iterators over tiles from scale/bias vectors
499
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
500
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
501
+ static int const kElementsPerAccess = 2;
502
+ using IteratorAccumulatorScaleBias =
503
+ cutlass::transform::threadblock::VectorIterator<
504
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
505
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
506
+ cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
507
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
508
+ >;
509
+
510
+ // Warp-level iterators to load scale and bias vectors
511
+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
512
+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
513
+ LayoutScaleBias, InstructionShape, kElementsPerAccess>;
514
+
515
+ // Define iterators over tiles from the B operand
516
+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
517
+ using IteratorB1 =
518
+ cutlass::conv::threadblock::TileIterator<
519
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
520
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
521
+ ElementB, LayoutB,
522
+ ThreadMapB1
523
+ >
524
+ >;
525
+
526
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
527
+
528
+ // Warp-level GEMM components
529
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
530
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
531
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
532
+
533
+ // Define the Mma
534
+ using B2bMma = threadblock::B2bImplicitGemmPipelined<
535
+ ThreadblockShape0,
536
+ IteratorA0,
537
+ SmemIteratorA0,
538
+ IteratorB0,
539
+ SmemIteratorB0,
540
+ ThreadblockShape1,
541
+ FragmentIteratorA1,
542
+ IteratorAccumulatorScaleBias,
543
+ FragmentIteratorA1ScaleBias,
544
+ IteratorB1,
545
+ SmemIteratorB1,
546
+ ElementC,
547
+ LayoutC,
548
+ EpilogueOutputOp0,
549
+ MmaPolicy0,
550
+ MmaPolicy1
551
+ >;
552
+
553
+ // Define the epilogue
554
+ using Epilogue = typename detail::DefaultConvEpilogue<
555
+ ArchTag,
556
+ ThreadblockShape1,
557
+ WarpMmaTensorOp1,
558
+ 1,
559
+ EpilogueOutputOp1
560
+ >::Epilogue;
561
+
562
+ // Define the kernel
563
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
564
+ B2bMma,
565
+ Epilogue,
566
+ ThreadblockSwizzle,
567
+ conv::Operator::kFprop
568
+ >;
569
+ };
570
+
571
+ /////////////////////////////////////////////////////////////////////////////////////////////////
572
+
573
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
574
+ /// pipeline with interleaved layout.
575
+ template <
576
+ typename ElementA,
577
+ typename ElementB,
578
+ typename ElementC,
579
+ typename LayoutC,
580
+ typename ElementAccumulator,
581
+ typename ArchTag,
582
+ typename ThreadblockShape0,
583
+ typename ThreadblockShape1,
584
+ typename WarpShape0,
585
+ typename WarpShape1,
586
+ typename InstructionShape,
587
+ typename EpilogueOutputOp0,
588
+ typename EpilogueOutputOp1,
589
+ typename ThreadblockSwizzle,
590
+ typename MathOperatorTag,
591
+ int InterleavedK
592
+ >
593
+ struct DefaultB2bConv2dFprop <
594
+ ElementA,
595
+ layout::TensorNCxHWx<InterleavedK>,
596
+ ElementB,
597
+ layout::TensorCxRSKx<InterleavedK>,
598
+ ElementC,
599
+ LayoutC,
600
+ ElementAccumulator,
601
+ arch::OpClassTensorOp,
602
+ ArchTag,
603
+ ThreadblockShape0,
604
+ ThreadblockShape1,
605
+ WarpShape0,
606
+ WarpShape1,
607
+ InstructionShape,
608
+ EpilogueOutputOp0,
609
+ EpilogueOutputOp1,
610
+ ThreadblockSwizzle,
611
+ 2,
612
+ MathOperatorTag,
613
+ IteratorAlgorithm::kOptimized
614
+ > {
615
+
616
+ // Define the core components from GEMM
617
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
618
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
619
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
620
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
621
+ 2, MathOperatorTag, true>;
622
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
623
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
624
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
625
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
626
+ 2, MathOperatorTag, true>;
627
+
628
+ // Define iterators over tiles from the A operand
629
+ // Note GEMM shared memory threadmap is used here because conv global memory
630
+ // layout needs to be mapped to fprop which is similar to the crosswise
631
+ // layout which is used by the interleaved GEMM shared memory threadmap.
632
+ // The Interleaved GEMM global memory layout is similar to the congruous
633
+ // layout.
634
+
635
+ // Define iterators over tiles from the A operand
636
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
637
+ using IteratorA0 =
638
+ cutlass::conv::threadblock::TileIterator<
639
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
640
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
641
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
642
+ ThreadMapA0
643
+ >
644
+ >;
645
+
646
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
647
+
648
+ // Define iterators over tiles from the B operand
649
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
650
+ using IteratorB0 =
651
+ cutlass::conv::threadblock::TileIterator<
652
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
653
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
654
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
655
+ ThreadMapB0
656
+ >
657
+ >;
658
+
659
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
660
+
661
+ // Use fragment iterator for A operand
662
+ using AccumulatorLayout = cutlass::layout::RowMajor;
663
+ using FragmentIteratorA1 =
664
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
665
+ cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
666
+ cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
667
+ MmaCore1::Shape::kK, //kBlocksColumn
668
+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
669
+
670
+ /// Define iterators over tiles from scale/bias vectors
671
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
672
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
673
+ static int const kElementsPerAccess = 4;
674
+ using IteratorAccumulatorScaleBias =
675
+ cutlass::transform::threadblock::VectorIterator<
676
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
677
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
678
+ cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
679
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
680
+ >;
681
+
682
+ // Warp-level iterators to load scale and bias vectors
683
+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
684
+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
685
+ LayoutScaleBias, InstructionShape, kElementsPerAccess>;
686
+
687
+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
688
+ using IteratorB1 =
689
+ cutlass::conv::threadblock::TileIterator<
690
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
691
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
692
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
693
+ ThreadMapB1
694
+ >
695
+ >;
696
+
697
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
698
+
699
+ // Warp-level GEMM components
700
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
701
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
702
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
703
+
704
+ // Define the Mma
705
+ using B2bMma = threadblock::B2bImplicitGemmPipelined<
706
+ ThreadblockShape0,
707
+ IteratorA0,
708
+ SmemIteratorA0,
709
+ IteratorB0,
710
+ SmemIteratorB0,
711
+ ThreadblockShape1,
712
+ FragmentIteratorA1,
713
+ IteratorAccumulatorScaleBias,
714
+ FragmentIteratorA1ScaleBias,
715
+ IteratorB1,
716
+ SmemIteratorB1,
717
+ ElementC,
718
+ LayoutC,
719
+ EpilogueOutputOp0,
720
+ MmaPolicy0,
721
+ MmaPolicy1
722
+ >;
723
+
724
+ // Define the epilogue
725
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
726
+ ThreadblockShape1,
727
+ WarpMmaTensorOp1,
728
+ 1,
729
+ EpilogueOutputOp1,
730
+ EpilogueOutputOp1::kCount,
731
+ InterleavedK
732
+ >::Epilogue;
733
+
734
+ // Define the kernel
735
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
736
+ B2bMma,
737
+ Epilogue,
738
+ ThreadblockSwizzle,
739
+ conv::Operator::kFprop
740
+ >;
741
+ };
742
+
743
+ /////////////////////////////////////////////////////////////////////////////////////////////////
744
+
745
+ } // namespace kernel
746
+ } // namespace conv
747
+ } // namespace cutlass
748
+
749
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h ADDED
@@ -0,0 +1,740 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
35
+ matrix multiply-add with the appropriate threadblock-scoped epilogue.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/conv/kernel/default_conv2d.h"
42
+
43
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
44
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
45
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
46
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
47
+
48
+ #include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
49
+ #include "cutlass/transform/threadblock/vector_iterator.h"
50
+ #include "cutlass/transform/warp/vector_fragment_iterator.h"
51
+
52
+ #include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
53
+
54
+ #include "kernel/default_b2b_conv2d_fprop.h"
55
+ #include "kernel/b2b_implicit_gemm_convolution.h"
56
+ #include "threadblock/b2b_implicit_gemm_multistage.h"
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace cutlass {
61
+ namespace conv {
62
+ namespace kernel {
63
+
64
+ /////////////////////////////////////////////////////////////////////////////////////////////////
65
+ // OpClassTensorOp convolutions
66
+ /////////////////////////////////////////////////////////////////////////////////////////////////
67
+
68
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
69
+ /// pipeline.
70
+ template <
71
+ typename ElementA,
72
+ typename LayoutA,
73
+ typename ElementB,
74
+ typename LayoutB,
75
+ typename ElementC,
76
+ typename LayoutC,
77
+ typename ElementAccumulator,
78
+ typename ArchTag,
79
+ typename ThreadblockShape0,
80
+ typename ThreadblockShape1,
81
+ typename WarpShape0,
82
+ typename WarpShape1,
83
+ typename InstructionShape,
84
+ typename EpilogueOutputOp0,
85
+ typename EpilogueOutputOp1,
86
+ typename ThreadblockSwizzle,
87
+ int Stages,
88
+ typename MathOperatorTag
89
+ >
90
+ struct DefaultB2bConv2dFprop <
91
+ ElementA,
92
+ LayoutA,
93
+ ElementB,
94
+ LayoutB,
95
+ ElementC,
96
+ LayoutC,
97
+ ElementAccumulator,
98
+ arch::OpClassTensorOp,
99
+ ArchTag,
100
+ ThreadblockShape0,
101
+ ThreadblockShape1,
102
+ WarpShape0,
103
+ WarpShape1,
104
+ InstructionShape,
105
+ EpilogueOutputOp0,
106
+ EpilogueOutputOp1,
107
+ ThreadblockSwizzle,
108
+ Stages,
109
+ MathOperatorTag,
110
+ IteratorAlgorithm::kAnalytic
111
+ > {
112
+
113
+ // Define the core components from GEMM
114
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
115
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
116
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
117
+ Stages, MathOperatorTag>;
118
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
119
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
120
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
121
+ Stages, MathOperatorTag>;
122
+
123
+ // Define iterators over tiles from the A operand
124
+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
125
+ using IteratorA0 =
126
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
127
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
128
+ ElementA, LayoutA,
129
+ ThreadMapA0
130
+ >;
131
+
132
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
133
+
134
+ // Define iterators over tiles from the B operand
135
+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
136
+ using IteratorB0 =
137
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
138
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
139
+ ElementB, LayoutB,
140
+ ThreadMapB0
141
+ >;
142
+
143
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
144
+
145
+ // Use fragment iterator for A operand
146
+ using AccumulatorLayout = cutlass::layout::ColumnMajor;
147
+ using FragmentIteratorA1 =
148
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
149
+ cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
150
+ cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
151
+ MmaCore1::Shape::kK, //kBlocksColumn
152
+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
153
+
154
+ /// Define iterators over tiles from scale/bias vectors
155
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
156
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
157
+ static int const kElementsPerAccess = 2;
158
+ using IteratorAccumulatorScaleBias =
159
+ cutlass::transform::threadblock::VectorIterator<
160
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
161
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
162
+ cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
163
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
164
+ >;
165
+
166
+ // Warp-level iterators to load scale and bias vectors
167
+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
168
+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
169
+ LayoutScaleBias, InstructionShape, kElementsPerAccess>;
170
+
171
+ // Define iterators over tiles from the B operand
172
+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
173
+ using IteratorB1 =
174
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
175
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
176
+ ElementB, LayoutB,
177
+ ThreadMapB1
178
+ >;
179
+
180
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
181
+
182
+ // Warp-level GEMM components
183
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
184
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
185
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
186
+
187
+ // Define the Mma
188
+ using B2bMma = threadblock::B2bImplicitGemmMultistage<
189
+ ThreadblockShape0,
190
+ IteratorA0,
191
+ SmemIteratorA0,
192
+ arch::CacheOperation::Always,
193
+ IteratorB0,
194
+ SmemIteratorB0,
195
+ arch::CacheOperation::Global,
196
+ ThreadblockShape1,
197
+ FragmentIteratorA1,
198
+ IteratorAccumulatorScaleBias,
199
+ FragmentIteratorA1ScaleBias,
200
+ IteratorB1,
201
+ SmemIteratorB1,
202
+ arch::CacheOperation::Global,
203
+ EpilogueOutputOp0,
204
+ MmaPolicy0,
205
+ MmaPolicy1,
206
+ Stages
207
+ >;
208
+
209
+ // Define the epilogue
210
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
211
+ ThreadblockShape1,
212
+ WarpMmaTensorOp1,
213
+ 1,
214
+ EpilogueOutputOp1,
215
+ EpilogueOutputOp1::kCount
216
+ >::Epilogue;
217
+
218
+ // Define the kernel
219
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
220
+ B2bMma,
221
+ Epilogue,
222
+ ThreadblockSwizzle,
223
+ conv::Operator::kFprop
224
+ >;
225
+ };
226
+
227
+ /////////////////////////////////////////////////////////////////////////////////////////////////
228
+
229
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
230
+ /// pipeline with interleaved layout.
231
+ template <
232
+ typename ElementA,
233
+ typename ElementB,
234
+ typename ElementC,
235
+ typename LayoutC,
236
+ typename ElementAccumulator,
237
+ typename ArchTag,
238
+ typename ThreadblockShape0,
239
+ typename ThreadblockShape1,
240
+ typename WarpShape0,
241
+ typename WarpShape1,
242
+ typename InstructionShape,
243
+ typename EpilogueOutputOp0,
244
+ typename EpilogueOutputOp1,
245
+ typename ThreadblockSwizzle,
246
+ int Stages,
247
+ typename MathOperatorTag,
248
+ int InterleavedK
249
+ >
250
+ struct DefaultB2bConv2dFprop <
251
+ ElementA,
252
+ layout::TensorNCxHWx<InterleavedK>,
253
+ ElementB,
254
+ layout::TensorCxRSKx<InterleavedK>,
255
+ ElementC,
256
+ LayoutC,
257
+ ElementAccumulator,
258
+ arch::OpClassTensorOp,
259
+ ArchTag,
260
+ ThreadblockShape0,
261
+ ThreadblockShape1,
262
+ WarpShape0,
263
+ WarpShape1,
264
+ InstructionShape,
265
+ EpilogueOutputOp0,
266
+ EpilogueOutputOp1,
267
+ ThreadblockSwizzle,
268
+ Stages,
269
+ MathOperatorTag,
270
+ IteratorAlgorithm::kAnalytic
271
+ > {
272
+
273
+ // Define the core components from GEMM
274
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
275
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
276
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
277
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
278
+ Stages, MathOperatorTag, true>;
279
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
280
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
281
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
282
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
283
+ Stages, MathOperatorTag, true>;
284
+
285
+ // Define iterators over tiles from the A operand
286
+ // Note GEMM shared memory threadmap is used here because conv global memory
287
+ // layout needs to be mapped to fprop which is similar to the crosswise
288
+ // layout which is used by the interleaved GEMM shared memory threadmap.
289
+ // The Interleaved GEMM global memory layout is similar to the congruous
290
+ // layout.
291
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
292
+ using IteratorA0 =
293
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
294
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
295
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
296
+ ThreadMapA0
297
+ >;
298
+
299
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
300
+
301
+ // Define iterators over tiles from the B operand
302
+ // Note GEMM shared memory threadmap is used here because conv global memory
303
+ // layout needs to be mapped to fprop which is similar to the crosswise
304
+ // layout which is used by the interleaved GEMM shared memory threadmap.
305
+ // The Interleaved GEMM global memory layout is similar to the congruous
306
+ // layout.
307
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
308
+ using IteratorB0 =
309
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
310
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
311
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
312
+ ThreadMapB0
313
+ >;
314
+
315
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
316
+
317
+ // Use fragment iterator for A operand
318
+ using AccumulatorLayout = cutlass::layout::RowMajor;
319
+ using FragmentIteratorA1 =
320
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
321
+ cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
322
+ cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
323
+ MmaCore1::Shape::kK, //kBlocksColumn
324
+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
325
+
326
+ /// Define iterators over tiles from scale/bias vectors
327
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
328
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
329
+ static int const kElementsPerAccess = 4;
330
+ using IteratorAccumulatorScaleBias =
331
+ cutlass::transform::threadblock::VectorIterator<
332
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
333
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
334
+ cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
335
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
336
+ >;
337
+
338
+ // Warp-level iterators to load scale and bias vectors
339
+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
340
+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
341
+ LayoutScaleBias, InstructionShape, kElementsPerAccess>;
342
+
343
+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
344
+ using IteratorB1 =
345
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
346
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
347
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
348
+ ThreadMapB1
349
+ >;
350
+
351
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
352
+
353
+
354
+ // Warp-level GEMM components
355
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
356
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
357
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
358
+
359
+ // Define the Mma
360
+ using B2bMma = threadblock::B2bImplicitGemmMultistage<
361
+ ThreadblockShape0,
362
+ IteratorA0,
363
+ SmemIteratorA0,
364
+ arch::CacheOperation::Always,
365
+ IteratorB0,
366
+ SmemIteratorB0,
367
+ arch::CacheOperation::Global,
368
+ ThreadblockShape1,
369
+ FragmentIteratorA1,
370
+ IteratorAccumulatorScaleBias,
371
+ FragmentIteratorA1ScaleBias,
372
+ IteratorB1,
373
+ SmemIteratorB1,
374
+ arch::CacheOperation::Global,
375
+ EpilogueOutputOp0,
376
+ MmaPolicy0,
377
+ MmaPolicy1,
378
+ Stages
379
+ >;
380
+
381
+ // Define the epilogue
382
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
383
+ ThreadblockShape1,
384
+ WarpMmaTensorOp1,
385
+ 1,
386
+ EpilogueOutputOp1,
387
+ EpilogueOutputOp1::kCount,
388
+ InterleavedK
389
+ >::Epilogue;
390
+
391
+ // Define the kernel
392
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
393
+ B2bMma,
394
+ Epilogue,
395
+ ThreadblockSwizzle,
396
+ conv::Operator::kFprop
397
+ >;
398
+ };
399
+
400
+ /////////////////////////////////////////////////////////////////////////////////////////////////
401
+
402
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
403
+ /// multistage pipeline.
404
+ template <
405
+ typename ElementA,
406
+ typename LayoutA,
407
+ typename ElementB,
408
+ typename LayoutB,
409
+ typename ElementC,
410
+ typename LayoutC,
411
+ typename ElementAccumulator,
412
+ typename ArchTag,
413
+ typename ThreadblockShape0,
414
+ typename ThreadblockShape1,
415
+ typename WarpShape0,
416
+ typename WarpShape1,
417
+ typename InstructionShape,
418
+ typename EpilogueOutputOp0,
419
+ typename EpilogueOutputOp1,
420
+ typename ThreadblockSwizzle,
421
+ int Stages,
422
+ typename MathOperatorTag
423
+ >
424
+ struct DefaultB2bConv2dFprop <
425
+ ElementA,
426
+ LayoutA,
427
+ ElementB,
428
+ LayoutB,
429
+ ElementC,
430
+ LayoutC,
431
+ ElementAccumulator,
432
+ arch::OpClassTensorOp,
433
+ ArchTag,
434
+ ThreadblockShape0,
435
+ ThreadblockShape1,
436
+ WarpShape0,
437
+ WarpShape1,
438
+ InstructionShape,
439
+ EpilogueOutputOp0,
440
+ EpilogueOutputOp1,
441
+ ThreadblockSwizzle,
442
+ Stages,
443
+ MathOperatorTag,
444
+ IteratorAlgorithm::kOptimized
445
+ > {
446
+
447
+ // Define the core components from GEMM
448
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
449
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
450
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
451
+ Stages, MathOperatorTag>;
452
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
453
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
454
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
455
+ Stages, MathOperatorTag>;
456
+
457
+ // Define iterators over tiles from the A operand
458
+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
459
+ using IteratorA0 =
460
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
461
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
462
+ ElementA, LayoutA,
463
+ ThreadMapA0
464
+ >;
465
+
466
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
467
+
468
+ // Define iterators over tiles from the B operand
469
+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
470
+ using IteratorB0 =
471
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
472
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
473
+ ElementB, LayoutB,
474
+ ThreadMapB0
475
+ >;
476
+
477
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
478
+
479
+ // Use fragment iterator for A operand
480
+ using AccumulatorLayout = cutlass::layout::ColumnMajor;
481
+ using FragmentIteratorA1 =
482
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
483
+ cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
484
+ cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
485
+ MmaCore1::Shape::kK, //kBlocksColumn
486
+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
487
+
488
+ /// Define iterators over tiles from scale/bias vectors
489
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
490
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
491
+ static int const kElementsPerAccess = 2;
492
+ using IteratorAccumulatorScaleBias =
493
+ cutlass::transform::threadblock::VectorIterator<
494
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
495
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
496
+ cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
497
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
498
+ >;
499
+
500
+ // Warp-level iterators to load scale and bias vectors
501
+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
502
+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
503
+ LayoutScaleBias, InstructionShape, kElementsPerAccess>;
504
+
505
+ // Define iterators over tiles from the B operand
506
+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
507
+ using IteratorB1 =
508
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
509
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
510
+ ElementB, LayoutB,
511
+ ThreadMapB1
512
+ >;
513
+
514
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
515
+
516
+ // Warp-level GEMM components
517
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
518
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
519
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
520
+
521
+ // Define the Mma
522
+ using B2bMma = threadblock::B2bImplicitGemmMultistage<
523
+ ThreadblockShape0,
524
+ IteratorA0,
525
+ SmemIteratorA0,
526
+ arch::CacheOperation::Always,
527
+ IteratorB0,
528
+ SmemIteratorB0,
529
+ arch::CacheOperation::Global,
530
+ ThreadblockShape1,
531
+ FragmentIteratorA1,
532
+ IteratorAccumulatorScaleBias,
533
+ FragmentIteratorA1ScaleBias,
534
+ IteratorB1,
535
+ SmemIteratorB1,
536
+ arch::CacheOperation::Global,
537
+ EpilogueOutputOp0,
538
+ MmaPolicy0,
539
+ MmaPolicy1,
540
+ Stages
541
+ >;
542
+
543
+ // Define the epilogue
544
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
545
+ ThreadblockShape1,
546
+ WarpMmaTensorOp1,
547
+ 1,
548
+ EpilogueOutputOp1,
549
+ EpilogueOutputOp1::kCount
550
+ >::Epilogue;
551
+
552
+ // Define the kernel
553
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
554
+ B2bMma,
555
+ Epilogue,
556
+ ThreadblockSwizzle,
557
+ conv::Operator::kFprop
558
+ >;
559
+ };
560
+
561
+ /////////////////////////////////////////////////////////////////////////////////////////////////
562
+
563
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
564
+ // multistage pipeline with interleaved layout.
565
+ template <
566
+ typename ElementA,
567
+ typename ElementB,
568
+ typename ElementC,
569
+ typename LayoutC,
570
+ typename ElementAccumulator,
571
+ typename ArchTag,
572
+ typename ThreadblockShape0,
573
+ typename ThreadblockShape1,
574
+ typename WarpShape0,
575
+ typename WarpShape1,
576
+ typename InstructionShape,
577
+ typename EpilogueOutputOp0,
578
+ typename EpilogueOutputOp1,
579
+ typename ThreadblockSwizzle,
580
+ int Stages,
581
+ typename MathOperatorTag,
582
+ int InterleavedK
583
+ >
584
+ struct DefaultB2bConv2dFprop <
585
+ ElementA,
586
+ layout::TensorNCxHWx<InterleavedK>,
587
+ ElementB,
588
+ layout::TensorCxRSKx<InterleavedK>,
589
+ ElementC,
590
+ LayoutC,
591
+ ElementAccumulator,
592
+ arch::OpClassTensorOp,
593
+ ArchTag,
594
+ ThreadblockShape0,
595
+ ThreadblockShape1,
596
+ WarpShape0,
597
+ WarpShape1,
598
+ InstructionShape,
599
+ EpilogueOutputOp0,
600
+ EpilogueOutputOp1,
601
+ ThreadblockSwizzle,
602
+ Stages,
603
+ MathOperatorTag,
604
+ IteratorAlgorithm::kOptimized
605
+ > {
606
+
607
+ // Define the core components from GEMM
608
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
609
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
610
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
611
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
612
+ Stages, MathOperatorTag, true>;
613
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
614
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
615
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
616
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
617
+ Stages, MathOperatorTag, true>;
618
+
619
+ // Define iterators over tiles from the A operand
620
+ // Note GEMM shared memory threadmap is used here because conv global memory
621
+ // layout needs to be mapped to fprop which is similar to the crosswise
622
+ // layout which is used by the interleaved GEMM shared memory threadmap.
623
+ // The Interleaved GEMM global memory layout is similar to the congruous
624
+ // layout.
625
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
626
+ using IteratorA0 =
627
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
628
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
629
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
630
+ ThreadMapA0
631
+ >;
632
+
633
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
634
+
635
+ // Define iterators over tiles from the B operand
636
+ // Note GEMM shared memory threadmap is used here because conv global memory
637
+ // layout needs to be mapped to fprop which is similar to the crosswise
638
+ // layout which is used by the interleaved GEMM shared memory threadmap.
639
+ // The Interleaved GEMM global memory layout is similar to the congruous
640
+ // layout.
641
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
642
+ using IteratorB0 =
643
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
644
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
645
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
646
+ ThreadMapB0
647
+ >;
648
+
649
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
650
+
651
+ // Use fragment iterator for A operand
652
+ using AccumulatorLayout = cutlass::layout::RowMajor;
653
+ using FragmentIteratorA1 =
654
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
655
+ cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
656
+ cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
657
+ MmaCore1::Shape::kK, //kBlocksColumn
658
+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>;
659
+
660
+ /// Define iterators over tiles from scale/bias vectors
661
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
662
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
663
+ static int const kElementsPerAccess = 4;
664
+ using IteratorAccumulatorScaleBias =
665
+ cutlass::transform::threadblock::VectorIterator<
666
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
667
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
668
+ cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
669
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
670
+ >;
671
+
672
+ // Warp-level iterators to load scale and bias vectors
673
+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
674
+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
675
+ LayoutScaleBias, InstructionShape, kElementsPerAccess>;
676
+
677
+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
678
+ using IteratorB1 =
679
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
680
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
681
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
682
+ ThreadMapB1
683
+ >;
684
+
685
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
686
+
687
+
688
+ // Warp-level GEMM components
689
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
690
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
691
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
692
+
693
+ // Define the Mma
694
+ using B2bMma = threadblock::B2bImplicitGemmMultistage<
695
+ ThreadblockShape0,
696
+ IteratorA0,
697
+ SmemIteratorA0,
698
+ arch::CacheOperation::Always,
699
+ IteratorB0,
700
+ SmemIteratorB0,
701
+ arch::CacheOperation::Global,
702
+ ThreadblockShape1,
703
+ FragmentIteratorA1,
704
+ IteratorAccumulatorScaleBias,
705
+ FragmentIteratorA1ScaleBias,
706
+ IteratorB1,
707
+ SmemIteratorB1,
708
+ arch::CacheOperation::Global,
709
+ EpilogueOutputOp0,
710
+ MmaPolicy0,
711
+ MmaPolicy1,
712
+ Stages
713
+ >;
714
+
715
+ // Define the epilogue
716
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
717
+ ThreadblockShape1,
718
+ WarpMmaTensorOp1,
719
+ 1,
720
+ EpilogueOutputOp1,
721
+ EpilogueOutputOp1::kCount,
722
+ InterleavedK
723
+ >::Epilogue;
724
+
725
+ // Define the kernel
726
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
727
+ B2bMma,
728
+ Epilogue,
729
+ ThreadblockSwizzle,
730
+ conv::Operator::kFprop
731
+ >;
732
+ };
733
+
734
+ /////////////////////////////////////////////////////////////////////////////////////////////////
735
+
736
+ } // namespace kernel
737
+ } // namespace conv
738
+ } // namespace cutlass
739
+
740
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h ADDED
@@ -0,0 +1,817 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
35
+ matrix multiply-add with the appropriate threadblock-scoped epilogue.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/conv/kernel/default_conv2d.h"
42
+
43
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
44
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
45
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
46
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
47
+
48
+ #include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
49
+ #include "cutlass/transform/threadblock/vector_iterator.h"
50
+ #include "cutlass/transform/warp/vector_fragment_iterator.h"
51
+
52
+ #include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
53
+
54
+ #include "kernel/default_b2b_conv2d_fprop.h"
55
+ #include "kernel/b2b_implicit_gemm_convolution.h"
56
+ #include "threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h"
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace cutlass {
61
+ namespace conv {
62
+ namespace kernel {
63
+
64
+ /////////////////////////////////////////////////////////////////////////////////////////////////
65
+
66
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm
67
+ /// and 2 stage pipeline.
68
+ /// Accumulator will be staged in shared memory.
69
+ template <
70
+ typename ElementA,
71
+ typename LayoutA,
72
+ typename ElementB,
73
+ typename LayoutB,
74
+ typename ElementC,
75
+ typename LayoutC,
76
+ typename ElementAccumulator,
77
+ typename ArchTag,
78
+ typename ThreadblockShape0,
79
+ typename ThreadblockShape1,
80
+ typename WarpShape0,
81
+ typename WarpShape1,
82
+ typename InstructionShape,
83
+ typename EpilogueOutputOp0,
84
+ typename EpilogueOutputOp1,
85
+ typename ThreadblockSwizzle,
86
+ typename MathOperatorTag
87
+ >
88
+ struct DefaultB2bConv2dFprop <
89
+ ElementA,
90
+ LayoutA,
91
+ ElementB,
92
+ LayoutB,
93
+ ElementC,
94
+ LayoutC,
95
+ ElementAccumulator,
96
+ arch::OpClassTensorOp,
97
+ ArchTag,
98
+ ThreadblockShape0,
99
+ ThreadblockShape1,
100
+ WarpShape0,
101
+ WarpShape1,
102
+ InstructionShape,
103
+ EpilogueOutputOp0,
104
+ EpilogueOutputOp1,
105
+ ThreadblockSwizzle,
106
+ 2,
107
+ MathOperatorTag,
108
+ IteratorAlgorithm::kAnalytic,
109
+ true
110
+ > {
111
+
112
+ // Define the core components from GEMM
113
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
114
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
115
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
116
+ 2, MathOperatorTag>;
117
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
118
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
119
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
120
+ 2, MathOperatorTag>;
121
+
122
+ // Define iterators over tiles from the A operand
123
+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
124
+ using IteratorA0 =
125
+ cutlass::conv::threadblock::TileIterator<
126
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
127
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
128
+ ElementA, LayoutA,
129
+ ThreadMapA0
130
+ >
131
+ >;
132
+
133
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
134
+
135
+ // Define iterators over tiles from the B operand
136
+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
137
+ using IteratorB0 =
138
+ cutlass::conv::threadblock::TileIterator<
139
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
140
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
141
+ ElementB, LayoutB,
142
+ ThreadMapB0
143
+ >
144
+ >;
145
+
146
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
147
+
148
+ /// Define iterators over tiles from scale/bias vectors
149
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
150
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
151
+ static int const kElementsPerAccess = 2;
152
+ using IteratorAccumulatorScaleBias =
153
+ cutlass::transform::threadblock::VectorIterator<
154
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
155
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
156
+ cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
157
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
158
+ >;
159
+
160
+ // Define iterators over tiles from the B operand
161
+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
162
+ using IteratorB1 =
163
+ cutlass::conv::threadblock::TileIterator<
164
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
165
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
166
+ ElementB, LayoutB,
167
+ ThreadMapB1
168
+ >
169
+ >;
170
+
171
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
172
+
173
+ // Warp-level GEMM components
174
+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
175
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
176
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
177
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
178
+
179
+ // Use fragment iterator for the accumulator
180
+ using SmemAccumulatorLayout = cutlass::layout::RowMajor;
181
+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
182
+ WarpShape0, InstructionShape,
183
+ ElementAccumulator,
184
+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
185
+ SmemAccumulatorLayout
186
+ >;
187
+
188
+ // Store Accumulator tiles to Shared Memory
189
+ using SmemIteratorD0 =
190
+ cutlass::epilogue::warp::TileIteratorTensorOp<
191
+ WarpShape0,
192
+ InstructionShape,
193
+ ElementC,
194
+ SmemAccumulatorLayout
195
+ >;
196
+
197
+ static int const kThreadCount = 32;
198
+ // load warp tile from Shared Memory accumulator
199
+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
200
+ MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
201
+ ElementA, SmemAccumulatorLayout,
202
+ MatrixShape<InstructionShape::kM, InstructionShape::kK>,
203
+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
204
+
205
+ // Define the Mma
206
+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
207
+ ThreadblockShape0,
208
+ IteratorA0,
209
+ SmemIteratorA0,
210
+ IteratorB0,
211
+ SmemIteratorB0,
212
+ IteratorAccumulatorScaleBias,
213
+ FragmentIteratorAccumulator,
214
+ SmemIteratorD0,
215
+ ThreadblockShape1,
216
+ WarpIteratorA1,
217
+ IteratorB1,
218
+ SmemIteratorB1,
219
+ ElementC,
220
+ LayoutC,
221
+ EpilogueOutputOp0,
222
+ MmaPolicy0,
223
+ MmaPolicy1
224
+ >;
225
+
226
+ // Define the epilogue
227
+ using Epilogue = typename detail::DefaultConvEpilogue<
228
+ ArchTag,
229
+ ThreadblockShape1,
230
+ WarpMmaTensorOp1,
231
+ 1,
232
+ EpilogueOutputOp1
233
+ >::Epilogue;
234
+
235
+ // Define the kernel
236
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
237
+ B2bMma,
238
+ Epilogue,
239
+ ThreadblockSwizzle,
240
+ conv::Operator::kFprop
241
+ >;
242
+ };
243
+
244
+ /////////////////////////////////////////////////////////////////////////////////////////////////
245
+
246
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage
247
+ /// pipeline with interleaved layout.
248
+ /// Accumulator will be staged in shared memory.
249
+ template <
250
+ typename ElementA,
251
+ typename ElementB,
252
+ typename ElementC,
253
+ typename LayoutC,
254
+ typename ElementAccumulator,
255
+ typename ArchTag,
256
+ typename ThreadblockShape0,
257
+ typename ThreadblockShape1,
258
+ typename WarpShape0,
259
+ typename WarpShape1,
260
+ typename InstructionShape,
261
+ typename EpilogueOutputOp0,
262
+ typename EpilogueOutputOp1,
263
+ typename ThreadblockSwizzle,
264
+ typename MathOperatorTag,
265
+ int InterleavedK
266
+ >
267
+ struct DefaultB2bConv2dFprop <
268
+ ElementA,
269
+ layout::TensorNCxHWx<InterleavedK>,
270
+ ElementB,
271
+ layout::TensorCxRSKx<InterleavedK>,
272
+ ElementC,
273
+ LayoutC,
274
+ ElementAccumulator,
275
+ arch::OpClassTensorOp,
276
+ ArchTag,
277
+ ThreadblockShape0,
278
+ ThreadblockShape1,
279
+ WarpShape0,
280
+ WarpShape1,
281
+ InstructionShape,
282
+ EpilogueOutputOp0,
283
+ EpilogueOutputOp1,
284
+ ThreadblockSwizzle,
285
+ 2,
286
+ MathOperatorTag,
287
+ IteratorAlgorithm::kAnalytic,
288
+ true
289
+ > {
290
+
291
+ // Define the core components from GEMM
292
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
293
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
294
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
295
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
296
+ 2, MathOperatorTag, true>;
297
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
298
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
299
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
300
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
301
+ 2, MathOperatorTag, true>;
302
+
303
+ // Define iterators over tiles from the A operand
304
+ // Note GEMM shared memory threadmap is used here because conv global memory
305
+ // layout needs to be mapped to fprop which is similar to the crosswise
306
+ // layout which is used by the interleaved GEMM shared memory threadmap.
307
+ // The Interleaved GEMM global memory layout is similar to the congruous
308
+ // layout.
309
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
310
+ using IteratorA0 =
311
+ cutlass::conv::threadblock::TileIterator<
312
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
313
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
314
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
315
+ ThreadMapA0
316
+ >
317
+ >;
318
+
319
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
320
+
321
+ // Define iterators over tiles from the B operand
322
+ // Note GEMM shared memory threadmap is used here because conv global memory
323
+ // layout needs to be mapped to fprop which is similar to the crosswise
324
+ // layout which is used by the interleaved GEMM shared memory threadmap.
325
+ // The Interleaved GEMM global memory layout is similar to the congruous
326
+ // layout.
327
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
328
+ using IteratorB0 =
329
+ cutlass::conv::threadblock::TileIterator<
330
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
331
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
332
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
333
+ ThreadMapB0
334
+ >
335
+ >;
336
+
337
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
338
+
339
+ /// Define iterators over tiles from scale/bias vectors
340
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
341
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
342
+ static int const kElementsPerAccess = 4; //For interleaved layout
343
+ using IteratorAccumulatorScaleBias =
344
+ cutlass::transform::threadblock::VectorIterator<
345
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
346
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
347
+ cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
348
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
349
+ >;
350
+
351
+ // Define iterators over tiles from the B operand
352
+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
353
+ using IteratorB1 =
354
+ cutlass::conv::threadblock::TileIterator<
355
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
356
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
357
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
358
+ ThreadMapB1
359
+ >
360
+ >;
361
+
362
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
363
+
364
+ // Warp-level GEMM components
365
+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
366
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
367
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
368
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
369
+
370
+ // Use fragment iterator for the accumulator
371
+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
372
+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
373
+ WarpShape0, InstructionShape,
374
+ ElementAccumulator,
375
+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
376
+ SmemAccumulatorLayout
377
+ >;
378
+
379
+
380
+ // Store Accumulator tiles to Shared Memory
381
+ using SmemIteratorD0 =
382
+ cutlass::epilogue::warp::TileIteratorTensorOp<
383
+ WarpShape0,
384
+ InstructionShape,
385
+ ElementC,
386
+ SmemAccumulatorLayout
387
+ >;
388
+
389
+ static int const kThreadCount = 32;
390
+ // load warp tile from Shared Memory accumulator
391
+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
392
+ MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
393
+ ElementA, SmemAccumulatorLayout,
394
+ MatrixShape<InstructionShape::kM, InstructionShape::kK>,
395
+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
396
+
397
+ // Define the Mma
398
+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
399
+ ThreadblockShape0,
400
+ IteratorA0,
401
+ SmemIteratorA0,
402
+ IteratorB0,
403
+ SmemIteratorB0,
404
+ IteratorAccumulatorScaleBias,
405
+ FragmentIteratorAccumulator,
406
+ SmemIteratorD0,
407
+ ThreadblockShape1,
408
+ WarpIteratorA1,
409
+ IteratorB1,
410
+ SmemIteratorB1,
411
+ ElementC,
412
+ LayoutC,
413
+ EpilogueOutputOp0,
414
+ MmaPolicy0,
415
+ MmaPolicy1
416
+ >;
417
+
418
+ // Define the epilogue
419
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
420
+ ThreadblockShape1,
421
+ WarpMmaTensorOp1,
422
+ 1,
423
+ EpilogueOutputOp1,
424
+ EpilogueOutputOp1::kCount,
425
+ InterleavedK
426
+ >::Epilogue;
427
+
428
+ // Define the kernel
429
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
430
+ B2bMma,
431
+ Epilogue,
432
+ ThreadblockSwizzle,
433
+ conv::Operator::kFprop
434
+ >;
435
+ };
436
+
437
+ /////////////////////////////////////////////////////////////////////////////////////////////////
438
+
439
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm
440
+ /// and 2 stage pipeline.
441
+ /// Accumulator will be staged in shared memory.
442
+ template <
443
+ typename ElementA,
444
+ typename LayoutA,
445
+ typename ElementB,
446
+ typename LayoutB,
447
+ typename ElementC,
448
+ typename LayoutC,
449
+ typename ElementAccumulator,
450
+ typename ArchTag,
451
+ typename ThreadblockShape0,
452
+ typename ThreadblockShape1,
453
+ typename WarpShape0,
454
+ typename WarpShape1,
455
+ typename InstructionShape,
456
+ typename EpilogueOutputOp0,
457
+ typename EpilogueOutputOp1,
458
+ typename ThreadblockSwizzle,
459
+ typename MathOperatorTag
460
+ >
461
+ struct DefaultB2bConv2dFprop <
462
+ ElementA,
463
+ LayoutA,
464
+ ElementB,
465
+ LayoutB,
466
+ ElementC,
467
+ LayoutC,
468
+ ElementAccumulator,
469
+ arch::OpClassTensorOp,
470
+ ArchTag,
471
+ ThreadblockShape0,
472
+ ThreadblockShape1,
473
+ WarpShape0,
474
+ WarpShape1,
475
+ InstructionShape,
476
+ EpilogueOutputOp0,
477
+ EpilogueOutputOp1,
478
+ ThreadblockSwizzle,
479
+ 2,
480
+ MathOperatorTag,
481
+ IteratorAlgorithm::kOptimized,
482
+ true
483
+ > {
484
+
485
+ // Define the core components from GEMM
486
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
487
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
488
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
489
+ 2, MathOperatorTag>;
490
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
491
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
492
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
493
+ 2, MathOperatorTag>;
494
+
495
+ // Define iterators over tiles from the A operand
496
+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
497
+ using IteratorA0 =
498
+ cutlass::conv::threadblock::TileIterator<
499
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
500
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
501
+ ElementA, LayoutA,
502
+ ThreadMapA0
503
+ >
504
+ >;
505
+
506
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
507
+
508
+ // Define iterators over tiles from the B operand
509
+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
510
+ using IteratorB0 =
511
+ cutlass::conv::threadblock::TileIterator<
512
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
513
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
514
+ ElementB, LayoutB,
515
+ ThreadMapB0
516
+ >
517
+ >;
518
+
519
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
520
+
521
+ /// Define iterators over tiles from scale/bias vectors
522
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
523
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
524
+ static int const kElementsPerAccess = 2;
525
+ using IteratorAccumulatorScaleBias =
526
+ cutlass::transform::threadblock::VectorIterator<
527
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
528
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
529
+ cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
530
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
531
+ >;
532
+
533
+ // Define iterators over tiles from the B operand
534
+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
535
+ using IteratorB1 =
536
+ cutlass::conv::threadblock::TileIterator<
537
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
538
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
539
+ ElementB, LayoutB,
540
+ ThreadMapB1
541
+ >
542
+ >;
543
+
544
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
545
+
546
+ // Warp-level GEMM components
547
+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
548
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
549
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
550
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
551
+
552
+ // Use fragment iterator for the accumulator
553
+ using SmemAccumulatorLayout = cutlass::layout::RowMajor;
554
+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
555
+ WarpShape0, InstructionShape,
556
+ ElementAccumulator,
557
+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
558
+ SmemAccumulatorLayout
559
+ >;
560
+
561
+ // Store Accumulator tiles to Shared Memory
562
+ using SmemIteratorD0 =
563
+ cutlass::epilogue::warp::TileIteratorTensorOp<
564
+ WarpShape0,
565
+ InstructionShape,
566
+ ElementC,
567
+ SmemAccumulatorLayout
568
+ >;
569
+
570
+ static int const kThreadCount = 32;
571
+ // load warp tile from Shared Memory accumulator
572
+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
573
+ MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
574
+ ElementA, SmemAccumulatorLayout,
575
+ MatrixShape<InstructionShape::kM, InstructionShape::kK>,
576
+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
577
+
578
+ // Define the Mma
579
+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
580
+ ThreadblockShape0,
581
+ IteratorA0,
582
+ SmemIteratorA0,
583
+ IteratorB0,
584
+ SmemIteratorB0,
585
+ IteratorAccumulatorScaleBias,
586
+ FragmentIteratorAccumulator,
587
+ SmemIteratorD0,
588
+ ThreadblockShape1,
589
+ WarpIteratorA1,
590
+ IteratorB1,
591
+ SmemIteratorB1,
592
+ ElementC,
593
+ LayoutC,
594
+ EpilogueOutputOp0,
595
+ MmaPolicy0,
596
+ MmaPolicy1
597
+ >;
598
+
599
+ // Define the epilogue
600
+ using Epilogue = typename detail::DefaultConvEpilogue<
601
+ ArchTag,
602
+ ThreadblockShape1,
603
+ WarpMmaTensorOp1,
604
+ 1,
605
+ EpilogueOutputOp1
606
+ >::Epilogue;
607
+
608
+ // Define the kernel
609
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
610
+ B2bMma,
611
+ Epilogue,
612
+ ThreadblockSwizzle,
613
+ conv::Operator::kFprop
614
+ >;
615
+ };
616
+
617
+ /////////////////////////////////////////////////////////////////////////////////////////////////
618
+
619
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage
620
+ /// pipeline with interleaved layout.
621
+ /// Accumulator will be staged in shared memory.
622
+ template <
623
+ typename ElementA,
624
+ typename ElementB,
625
+ typename ElementC,
626
+ typename LayoutC,
627
+ typename ElementAccumulator,
628
+ typename ArchTag,
629
+ typename ThreadblockShape0,
630
+ typename ThreadblockShape1,
631
+ typename WarpShape0,
632
+ typename WarpShape1,
633
+ typename InstructionShape,
634
+ typename EpilogueOutputOp0,
635
+ typename EpilogueOutputOp1,
636
+ typename ThreadblockSwizzle,
637
+ typename MathOperatorTag,
638
+ int InterleavedK
639
+ >
640
+ struct DefaultB2bConv2dFprop <
641
+ ElementA,
642
+ layout::TensorNCxHWx<InterleavedK>,
643
+ ElementB,
644
+ layout::TensorCxRSKx<InterleavedK>,
645
+ ElementC,
646
+ LayoutC,
647
+ ElementAccumulator,
648
+ arch::OpClassTensorOp,
649
+ ArchTag,
650
+ ThreadblockShape0,
651
+ ThreadblockShape1,
652
+ WarpShape0,
653
+ WarpShape1,
654
+ InstructionShape,
655
+ EpilogueOutputOp0,
656
+ EpilogueOutputOp1,
657
+ ThreadblockSwizzle,
658
+ 2,
659
+ MathOperatorTag,
660
+ IteratorAlgorithm::kOptimized,
661
+ true
662
+ > {
663
+
664
+ // Define the core components from GEMM
665
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
666
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
667
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
668
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
669
+ 2, MathOperatorTag, true>;
670
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
671
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
672
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
673
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
674
+ 2, MathOperatorTag, true>;
675
+
676
+ // Define iterators over tiles from the A operand
677
+ // Note GEMM shared memory threadmap is used here because conv global memory
678
+ // layout needs to be mapped to fprop which is similar to the crosswise
679
+ // layout which is used by the interleaved GEMM shared memory threadmap.
680
+ // The Interleaved GEMM global memory layout is similar to the congruous
681
+ // layout.
682
+
683
+ // Define iterators over tiles from the A operand
684
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
685
+ using IteratorA0 =
686
+ cutlass::conv::threadblock::TileIterator<
687
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
688
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
689
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
690
+ ThreadMapA0
691
+ >
692
+ >;
693
+
694
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
695
+
696
+ // Define iterators over tiles from the B operand
697
+ // Note GEMM shared memory threadmap is used here because conv global memory
698
+ // layout needs to be mapped to fprop which is similar to the crosswise
699
+ // layout which is used by the interleaved GEMM shared memory threadmap.
700
+ // The Interleaved GEMM global memory layout is similar to the congruous
701
+ // layout.
702
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
703
+ using IteratorB0 =
704
+ cutlass::conv::threadblock::TileIterator<
705
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
706
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
707
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
708
+ ThreadMapB0
709
+ >
710
+ >;
711
+
712
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
713
+
714
+ /// Define iterators over tiles from scale/bias vectors
715
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
716
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
717
+ static int const kElementsPerAccess = 4; //For interleaved layout
718
+ using IteratorAccumulatorScaleBias =
719
+ cutlass::transform::threadblock::VectorIterator<
720
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
721
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
722
+ cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
723
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
724
+ >;
725
+
726
+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
727
+ using IteratorB1 =
728
+ cutlass::conv::threadblock::TileIterator<
729
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
730
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
731
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
732
+ ThreadMapB1
733
+ >
734
+ >;
735
+
736
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
737
+
738
+ // Warp-level GEMM components
739
+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
740
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
741
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
742
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
743
+
744
+ // Use fragment iterator for the accumulator
745
+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
746
+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
747
+ WarpShape0, InstructionShape,
748
+ ElementAccumulator,
749
+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
750
+ SmemAccumulatorLayout
751
+ >;
752
+
753
+
754
+ // Store Accumulator tiles to Shared Memory
755
+ using SmemIteratorD0 =
756
+ cutlass::epilogue::warp::TileIteratorTensorOp<
757
+ WarpShape0,
758
+ InstructionShape,
759
+ ElementC,
760
+ SmemAccumulatorLayout
761
+ >;
762
+
763
+ static int const kThreadCount = 32;
764
+ // load warp tile from Shared Memory accumulator
765
+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
766
+ MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
767
+ ElementA, SmemAccumulatorLayout,
768
+ MatrixShape<InstructionShape::kM, InstructionShape::kK>,
769
+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
770
+
771
+ // Define the Mma
772
+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator<
773
+ ThreadblockShape0,
774
+ IteratorA0,
775
+ SmemIteratorA0,
776
+ IteratorB0,
777
+ SmemIteratorB0,
778
+ IteratorAccumulatorScaleBias,
779
+ FragmentIteratorAccumulator,
780
+ SmemIteratorD0,
781
+ ThreadblockShape1,
782
+ WarpIteratorA1,
783
+ IteratorB1,
784
+ SmemIteratorB1,
785
+ ElementC,
786
+ LayoutC,
787
+ EpilogueOutputOp0,
788
+ MmaPolicy0,
789
+ MmaPolicy1
790
+ >;
791
+
792
+ // Define the epilogue
793
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
794
+ ThreadblockShape1,
795
+ WarpMmaTensorOp1,
796
+ 1,
797
+ EpilogueOutputOp1,
798
+ EpilogueOutputOp1::kCount,
799
+ InterleavedK
800
+ >::Epilogue;
801
+
802
+ // Define the kernel
803
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
804
+ B2bMma,
805
+ Epilogue,
806
+ ThreadblockSwizzle,
807
+ conv::Operator::kFprop
808
+ >;
809
+ };
810
+
811
+ /////////////////////////////////////////////////////////////////////////////////////////////////
812
+
813
+ } // namespace kernel
814
+ } // namespace conv
815
+ } // namespace cutlass
816
+
817
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped
35
+ matrix multiply-add with the appropriate threadblock-scoped epilogue.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/conv/kernel/default_conv2d.h"
42
+
43
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h"
44
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h"
45
+ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h"
46
+ #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h"
47
+
48
+ #include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
49
+ #include "cutlass/transform/threadblock/vector_iterator.h"
50
+ #include "cutlass/transform/warp/vector_fragment_iterator.h"
51
+
52
+ #include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
53
+
54
+ #include "kernel/default_b2b_conv2d_fprop.h"
55
+ #include "kernel/b2b_implicit_gemm_convolution.h"
56
+ #include "threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h"
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace cutlass {
61
+ namespace conv {
62
+ namespace kernel {
63
+
64
+ /////////////////////////////////////////////////////////////////////////////////////////////////
65
+
66
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
67
+ /// pipeline.
68
+ /// Accumulator will be staged in shared memory.
69
+ template <
70
+ typename ElementA,
71
+ typename LayoutA,
72
+ typename ElementB,
73
+ typename LayoutB,
74
+ typename ElementC,
75
+ typename LayoutC,
76
+ typename ElementAccumulator,
77
+ typename ArchTag,
78
+ typename ThreadblockShape0,
79
+ typename ThreadblockShape1,
80
+ typename WarpShape0,
81
+ typename WarpShape1,
82
+ typename InstructionShape,
83
+ typename EpilogueOutputOp0,
84
+ typename EpilogueOutputOp1,
85
+ typename ThreadblockSwizzle,
86
+ int Stages,
87
+ typename MathOperatorTag
88
+ >
89
+ struct DefaultB2bConv2dFprop <
90
+ ElementA,
91
+ LayoutA,
92
+ ElementB,
93
+ LayoutB,
94
+ ElementC,
95
+ LayoutC,
96
+ ElementAccumulator,
97
+ arch::OpClassTensorOp,
98
+ ArchTag,
99
+ ThreadblockShape0,
100
+ ThreadblockShape1,
101
+ WarpShape0,
102
+ WarpShape1,
103
+ InstructionShape,
104
+ EpilogueOutputOp0,
105
+ EpilogueOutputOp1,
106
+ ThreadblockSwizzle,
107
+ Stages,
108
+ MathOperatorTag,
109
+ IteratorAlgorithm::kAnalytic,
110
+ true
111
+ > {
112
+
113
+ // Define the core components from GEMM
114
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
115
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
116
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
117
+ Stages, MathOperatorTag>;
118
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
119
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
120
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
121
+ Stages, MathOperatorTag>;
122
+
123
+ // Define iterators over tiles from the A operand
124
+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
125
+ using IteratorA0 =
126
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
127
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
128
+ ElementA, LayoutA,
129
+ ThreadMapA0
130
+ >;
131
+
132
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
133
+
134
+ // Define iterators over tiles from the B operand
135
+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
136
+ using IteratorB0 =
137
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
138
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
139
+ ElementB, LayoutB,
140
+ ThreadMapB0
141
+ >;
142
+
143
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
144
+
145
+ /// Define iterators over tiles from scale/bias vectors
146
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
147
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
148
+ static int const kElementsPerAccess = 2;
149
+ using IteratorAccumulatorScaleBias =
150
+ cutlass::transform::threadblock::VectorIterator<
151
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
152
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
153
+ cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
154
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
155
+ >;
156
+
157
+ // Define iterators over tiles from the B operand
158
+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
159
+ using IteratorB1 =
160
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
161
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
162
+ ElementB, LayoutB,
163
+ ThreadMapB1
164
+ >;
165
+
166
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
167
+
168
+ // Warp-level GEMM components
169
+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
170
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
171
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
172
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
173
+
174
+ // Use fragment iterator for the accumulator
175
+ using SmemAccumulatorLayout = cutlass::layout::RowMajor;
176
+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
177
+ WarpShape0, InstructionShape,
178
+ ElementAccumulator,
179
+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
180
+ SmemAccumulatorLayout
181
+ >;
182
+
183
+ // Store Accumulator tiles to Shared Memory
184
+ using SmemIteratorD0 =
185
+ cutlass::epilogue::warp::TileIteratorTensorOp<
186
+ WarpShape0,
187
+ InstructionShape,
188
+ ElementC,
189
+ SmemAccumulatorLayout
190
+ >;
191
+
192
+ static int const kThreadCount = 32;
193
+ // load warp tile from Shared Memory accumulator
194
+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
195
+ MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
196
+ ElementA, SmemAccumulatorLayout,
197
+ MatrixShape<InstructionShape::kM, InstructionShape::kK>,
198
+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
199
+
200
+ // Define the Mma
201
+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
202
+ ThreadblockShape0,
203
+ IteratorA0,
204
+ SmemIteratorA0,
205
+ arch::CacheOperation::Always,
206
+ IteratorB0,
207
+ SmemIteratorB0,
208
+ arch::CacheOperation::Global,
209
+ IteratorAccumulatorScaleBias,
210
+ FragmentIteratorAccumulator,
211
+ SmemIteratorD0,
212
+ ThreadblockShape1,
213
+ WarpIteratorA1,
214
+ IteratorB1,
215
+ SmemIteratorB1,
216
+ arch::CacheOperation::Global,
217
+ EpilogueOutputOp0,
218
+ MmaPolicy0,
219
+ MmaPolicy1,
220
+ Stages
221
+ >;
222
+
223
+ // Define the epilogue
224
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
225
+ ThreadblockShape1,
226
+ WarpMmaTensorOp1,
227
+ 1,
228
+ EpilogueOutputOp1,
229
+ EpilogueOutputOp1::kCount
230
+ >::Epilogue;
231
+
232
+ // Define the kernel
233
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
234
+ B2bMma,
235
+ Epilogue,
236
+ ThreadblockSwizzle,
237
+ conv::Operator::kFprop
238
+ >;
239
+ };
240
+
241
+ /////////////////////////////////////////////////////////////////////////////////////////////////
242
+
243
+ /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage
244
+ /// pipeline with interleaved layout.
245
+ /// Accumulator will be staged in shared memory.
246
+ template <
247
+ typename ElementA,
248
+ typename ElementB,
249
+ typename ElementC,
250
+ typename LayoutC,
251
+ typename ElementAccumulator,
252
+ typename ArchTag,
253
+ typename ThreadblockShape0,
254
+ typename ThreadblockShape1,
255
+ typename WarpShape0,
256
+ typename WarpShape1,
257
+ typename InstructionShape,
258
+ typename EpilogueOutputOp0,
259
+ typename EpilogueOutputOp1,
260
+ typename ThreadblockSwizzle,
261
+ int Stages,
262
+ typename MathOperatorTag,
263
+ int InterleavedK
264
+ >
265
+ struct DefaultB2bConv2dFprop <
266
+ ElementA,
267
+ layout::TensorNCxHWx<InterleavedK>,
268
+ ElementB,
269
+ layout::TensorCxRSKx<InterleavedK>,
270
+ ElementC,
271
+ LayoutC,
272
+ ElementAccumulator,
273
+ arch::OpClassTensorOp,
274
+ ArchTag,
275
+ ThreadblockShape0,
276
+ ThreadblockShape1,
277
+ WarpShape0,
278
+ WarpShape1,
279
+ InstructionShape,
280
+ EpilogueOutputOp0,
281
+ EpilogueOutputOp1,
282
+ ThreadblockSwizzle,
283
+ Stages,
284
+ MathOperatorTag,
285
+ IteratorAlgorithm::kAnalytic,
286
+ true
287
+ > {
288
+
289
+ // Define the core components from GEMM
290
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
291
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
292
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
293
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
294
+ Stages, MathOperatorTag, true>;
295
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
296
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
297
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
298
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
299
+ Stages, MathOperatorTag, true>;
300
+
301
+ // Define iterators over tiles from the A operand
302
+ // Note GEMM shared memory threadmap is used here because conv global memory
303
+ // layout needs to be mapped to fprop which is similar to the crosswise
304
+ // layout which is used by the interleaved GEMM shared memory threadmap.
305
+ // The Interleaved GEMM global memory layout is similar to the congruous
306
+ // layout.
307
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
308
+ using IteratorA0 =
309
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
310
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
311
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
312
+ ThreadMapA0
313
+ >;
314
+
315
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
316
+
317
+ // Define iterators over tiles from the B operand
318
+ // Note GEMM shared memory threadmap is used here because conv global memory
319
+ // layout needs to be mapped to fprop which is similar to the crosswise
320
+ // layout which is used by the interleaved GEMM shared memory threadmap.
321
+ // The Interleaved GEMM global memory layout is similar to the congruous
322
+ // layout.
323
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
324
+ using IteratorB0 =
325
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
326
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
327
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
328
+ ThreadMapB0
329
+ >;
330
+
331
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
332
+
333
+ /// Define iterators over tiles from scale/bias vectors
334
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
335
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
336
+ static int const kElementsPerAccess = 4;
337
+ using IteratorAccumulatorScaleBias =
338
+ cutlass::transform::threadblock::VectorIterator<
339
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
340
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
341
+ cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
342
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
343
+ >;
344
+
345
+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
346
+ using IteratorB1 =
347
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
348
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
349
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
350
+ ThreadMapB1
351
+ >;
352
+
353
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
354
+
355
+ // Warp-level GEMM components
356
+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
357
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
358
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
359
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
360
+
361
+ // Use fragment iterator for the accumulator
362
+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
363
+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
364
+ WarpShape0, InstructionShape,
365
+ ElementAccumulator,
366
+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
367
+ SmemAccumulatorLayout
368
+ >;
369
+
370
+
371
+ // Store Accumulator tiles to Shared Memory
372
+ using SmemIteratorD0 =
373
+ cutlass::epilogue::warp::TileIteratorTensorOp<
374
+ WarpShape0,
375
+ InstructionShape,
376
+ ElementC,
377
+ SmemAccumulatorLayout
378
+ >;
379
+
380
+ static int const kThreadCount = 32;
381
+ // load warp tile from Shared Memory accumulator
382
+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
383
+ MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
384
+ ElementA, SmemAccumulatorLayout,
385
+ MatrixShape<InstructionShape::kM, InstructionShape::kK>,
386
+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
387
+
388
+ // Define the Mma
389
+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
390
+ ThreadblockShape0,
391
+ IteratorA0,
392
+ SmemIteratorA0,
393
+ arch::CacheOperation::Always,
394
+ IteratorB0,
395
+ SmemIteratorB0,
396
+ arch::CacheOperation::Global,
397
+ IteratorAccumulatorScaleBias,
398
+ FragmentIteratorAccumulator,
399
+ SmemIteratorD0,
400
+ ThreadblockShape1,
401
+ WarpIteratorA1,
402
+ IteratorB1,
403
+ SmemIteratorB1,
404
+ arch::CacheOperation::Global,
405
+ EpilogueOutputOp0,
406
+ MmaPolicy0,
407
+ MmaPolicy1,
408
+ Stages
409
+ >;
410
+
411
+ // Define the epilogue
412
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
413
+ ThreadblockShape1,
414
+ WarpMmaTensorOp1,
415
+ 1,
416
+ EpilogueOutputOp1,
417
+ EpilogueOutputOp1::kCount,
418
+ InterleavedK
419
+ >::Epilogue;
420
+
421
+ // Define the kernel
422
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
423
+ B2bMma,
424
+ Epilogue,
425
+ ThreadblockSwizzle,
426
+ conv::Operator::kFprop
427
+ >;
428
+ };
429
+
430
+ /////////////////////////////////////////////////////////////////////////////////////////////////
431
+
432
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
433
+ /// multistage pipeline.
434
+ /// Accumulator will be staged in shared memory.
435
+ template <
436
+ typename ElementA,
437
+ typename LayoutA,
438
+ typename ElementB,
439
+ typename LayoutB,
440
+ typename ElementC,
441
+ typename LayoutC,
442
+ typename ElementAccumulator,
443
+ typename ArchTag,
444
+ typename ThreadblockShape0,
445
+ typename ThreadblockShape1,
446
+ typename WarpShape0,
447
+ typename WarpShape1,
448
+ typename InstructionShape,
449
+ typename EpilogueOutputOp0,
450
+ typename EpilogueOutputOp1,
451
+ typename ThreadblockSwizzle,
452
+ int Stages,
453
+ typename MathOperatorTag
454
+ >
455
+ struct DefaultB2bConv2dFprop <
456
+ ElementA,
457
+ LayoutA,
458
+ ElementB,
459
+ LayoutB,
460
+ ElementC,
461
+ LayoutC,
462
+ ElementAccumulator,
463
+ arch::OpClassTensorOp,
464
+ ArchTag,
465
+ ThreadblockShape0,
466
+ ThreadblockShape1,
467
+ WarpShape0,
468
+ WarpShape1,
469
+ InstructionShape,
470
+ EpilogueOutputOp0,
471
+ EpilogueOutputOp1,
472
+ ThreadblockSwizzle,
473
+ Stages,
474
+ MathOperatorTag,
475
+ IteratorAlgorithm::kOptimized,
476
+ true
477
+ > {
478
+
479
+ // Define the core components from GEMM
480
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
481
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor,
482
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
483
+ Stages, MathOperatorTag>;
484
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
485
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor,
486
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
487
+ Stages, MathOperatorTag>;
488
+
489
+ // Define iterators over tiles from the A operand
490
+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA;
491
+ using IteratorA0 =
492
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
493
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
494
+ ElementA, LayoutA,
495
+ ThreadMapA0
496
+ >;
497
+
498
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
499
+
500
+ // Define iterators over tiles from the B operand
501
+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB;
502
+ using IteratorB0 =
503
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
504
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
505
+ ElementB, LayoutB,
506
+ ThreadMapB0
507
+ >;
508
+
509
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
510
+
511
+ /// Define iterators over tiles from scale/bias vectors
512
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
513
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
514
+ static int const kElementsPerAccess = 2;
515
+ using IteratorAccumulatorScaleBias =
516
+ cutlass::transform::threadblock::VectorIterator<
517
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
518
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
519
+ cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
520
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
521
+ >;
522
+
523
+ // Define iterators over tiles from the B operand
524
+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
525
+ using IteratorB1 =
526
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
527
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
528
+ ElementB, LayoutB,
529
+ ThreadMapB1
530
+ >;
531
+
532
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
533
+
534
+ // Warp-level GEMM components
535
+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
536
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
537
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
538
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
539
+
540
+ // Use fragment iterator for the accumulator
541
+ using SmemAccumulatorLayout = cutlass::layout::RowMajor;
542
+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
543
+ WarpShape0, InstructionShape,
544
+ ElementAccumulator,
545
+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
546
+ SmemAccumulatorLayout
547
+ >;
548
+
549
+ // Store Accumulator tiles to Shared Memory
550
+ using SmemIteratorD0 =
551
+ cutlass::epilogue::warp::TileIteratorTensorOp<
552
+ WarpShape0,
553
+ InstructionShape,
554
+ ElementC,
555
+ SmemAccumulatorLayout
556
+ >;
557
+
558
+ static int const kThreadCount = 32;
559
+ // load warp tile from Shared Memory accumulator
560
+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator<
561
+ MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
562
+ ElementA, SmemAccumulatorLayout,
563
+ MatrixShape<InstructionShape::kM, InstructionShape::kK>,
564
+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
565
+
566
+ // Define the Mma
567
+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
568
+ ThreadblockShape0,
569
+ IteratorA0,
570
+ SmemIteratorA0,
571
+ arch::CacheOperation::Always,
572
+ IteratorB0,
573
+ SmemIteratorB0,
574
+ arch::CacheOperation::Global,
575
+ IteratorAccumulatorScaleBias,
576
+ FragmentIteratorAccumulator,
577
+ SmemIteratorD0,
578
+ ThreadblockShape1,
579
+ WarpIteratorA1,
580
+ IteratorB1,
581
+ SmemIteratorB1,
582
+ arch::CacheOperation::Global,
583
+ EpilogueOutputOp0,
584
+ MmaPolicy0,
585
+ MmaPolicy1,
586
+ Stages
587
+ >;
588
+
589
+ // Define the epilogue
590
+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
591
+ ThreadblockShape1,
592
+ WarpMmaTensorOp1,
593
+ 1,
594
+ EpilogueOutputOp1,
595
+ EpilogueOutputOp1::kCount
596
+ >::Epilogue;
597
+
598
+ // Define the kernel
599
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
600
+ B2bMma,
601
+ Epilogue,
602
+ ThreadblockSwizzle,
603
+ conv::Operator::kFprop
604
+ >;
605
+ };
606
+
607
+ /////////////////////////////////////////////////////////////////////////////////////////////////
608
+
609
+ /// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
610
+ // multistage pipeline with interleaved layout.
611
+ /// Accumulator will be staged in shared memory.
612
+ template <
613
+ typename ElementA,
614
+ typename ElementB,
615
+ typename ElementC,
616
+ typename LayoutC,
617
+ typename ElementAccumulator,
618
+ typename ArchTag,
619
+ typename ThreadblockShape0,
620
+ typename ThreadblockShape1,
621
+ typename WarpShape0,
622
+ typename WarpShape1,
623
+ typename InstructionShape,
624
+ typename EpilogueOutputOp0,
625
+ typename EpilogueOutputOp1,
626
+ typename ThreadblockSwizzle,
627
+ int Stages,
628
+ typename MathOperatorTag,
629
+ int InterleavedK
630
+ >
631
+ struct DefaultB2bConv2dFprop <
632
+ ElementA,
633
+ layout::TensorNCxHWx<InterleavedK>,
634
+ ElementB,
635
+ layout::TensorCxRSKx<InterleavedK>,
636
+ ElementC,
637
+ LayoutC,
638
+ ElementAccumulator,
639
+ arch::OpClassTensorOp,
640
+ ArchTag,
641
+ ThreadblockShape0,
642
+ ThreadblockShape1,
643
+ WarpShape0,
644
+ WarpShape1,
645
+ InstructionShape,
646
+ EpilogueOutputOp0,
647
+ EpilogueOutputOp1,
648
+ ThreadblockSwizzle,
649
+ Stages,
650
+ MathOperatorTag,
651
+ IteratorAlgorithm::kOptimized,
652
+ true
653
+ > {
654
+
655
+ // Define the core components from GEMM
656
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
657
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
658
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
659
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
660
+ Stages, MathOperatorTag, true>;
661
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
662
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
663
+ ElementB, layout::RowMajorInterleaved<InterleavedK>,
664
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp,
665
+ Stages, MathOperatorTag, true>;
666
+
667
+ // Define iterators over tiles from the A operand
668
+ // Note GEMM shared memory threadmap is used here because conv global memory
669
+ // layout needs to be mapped to fprop which is similar to the crosswise
670
+ // layout which is used by the interleaved GEMM shared memory threadmap.
671
+ // The Interleaved GEMM global memory layout is similar to the congruous
672
+ // layout.
673
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA;
674
+ using IteratorA0 =
675
+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<
676
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kK>,
677
+ ElementA, layout::TensorNCxHWx<InterleavedK>,
678
+ ThreadMapA0
679
+ >;
680
+
681
+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA;
682
+
683
+ // Define iterators over tiles from the B operand
684
+ // Note GEMM shared memory threadmap is used here because conv global memory
685
+ // layout needs to be mapped to fprop which is similar to the crosswise
686
+ // layout which is used by the interleaved GEMM shared memory threadmap.
687
+ // The Interleaved GEMM global memory layout is similar to the congruous
688
+ // layout.
689
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB;
690
+ using IteratorB0 =
691
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
692
+ cutlass::MatrixShape<ThreadblockShape0::kK, ThreadblockShape0::kN>,
693
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
694
+ ThreadMapB0
695
+ >;
696
+
697
+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB;
698
+
699
+ /// Define iterators over tiles from scale/bias vectors
700
+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
701
+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
702
+ static int const kElementsPerAccess = 4;
703
+ using IteratorAccumulatorScaleBias =
704
+ cutlass::transform::threadblock::VectorIterator<
705
+ cutlass::transform::threadblock::PredicatedVectorAccessIterator<
706
+ cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
707
+ cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
708
+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
709
+ >;
710
+
711
+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB;
712
+ using IteratorB1 =
713
+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
714
+ cutlass::MatrixShape<ThreadblockShape1::kK, ThreadblockShape1::kN>,
715
+ ElementB, layout::TensorCxRSKx<InterleavedK>,
716
+ ThreadMapB1
717
+ >;
718
+
719
+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB;
720
+
721
+
722
+ // Warp-level GEMM components
723
+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp;
724
+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp;
725
+ using MmaPolicy0 = typename MmaCore0::MmaPolicy;
726
+ using MmaPolicy1 = typename MmaCore1::MmaPolicy;
727
+
728
+ // Use fragment iterator for the accumulator
729
+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>;
730
+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
731
+ WarpShape0, InstructionShape,
732
+ ElementAccumulator,
733
+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC,
734
+ SmemAccumulatorLayout
735
+ >;
736
+
737
+
738
+ // Store Accumulator tiles to Shared Memory
739
+ using SmemIteratorD0 =
740
+ cutlass::epilogue::warp::TileIteratorTensorOp<
741
+ WarpShape0,
742
+ InstructionShape,
743
+ ElementC,
744
+ SmemAccumulatorLayout
745
+ >;
746
+
747
+ static int const kThreadCount = 32;
748
+ // load warp tile from Shared Memory accumulator
749
+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical<
750
+ MatrixShape<WarpShape1::kM, InstructionShape::kK>, cutlass::gemm::Operand::kA,
751
+ ElementA, SmemAccumulatorLayout,
752
+ MatrixShape<InstructionShape::kM, InstructionShape::kK>,
753
+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>;
754
+
755
+ // Define the Mma
756
+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator<
757
+ ThreadblockShape0,
758
+ IteratorA0,
759
+ SmemIteratorA0,
760
+ arch::CacheOperation::Always,
761
+ IteratorB0,
762
+ SmemIteratorB0,
763
+ arch::CacheOperation::Global,
764
+ IteratorAccumulatorScaleBias,
765
+ FragmentIteratorAccumulator,
766
+ SmemIteratorD0,
767
+ ThreadblockShape1,
768
+ WarpIteratorA1,
769
+ IteratorB1,
770
+ SmemIteratorB1,
771
+ arch::CacheOperation::Global,
772
+ EpilogueOutputOp0,
773
+ MmaPolicy0,
774
+ MmaPolicy1,
775
+ Stages
776
+ >;
777
+
778
+ // Define the epilogue
779
+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue<
780
+ ThreadblockShape1,
781
+ WarpMmaTensorOp1,
782
+ 1,
783
+ EpilogueOutputOp1,
784
+ EpilogueOutputOp1::kCount,
785
+ InterleavedK
786
+ >::Epilogue;
787
+
788
+ // Define the kernel
789
+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution<
790
+ B2bMma,
791
+ Epilogue,
792
+ ThreadblockSwizzle,
793
+ conv::Operator::kFprop
794
+ >;
795
+ };
796
+
797
+
798
+ /////////////////////////////////////////////////////////////////////////////////////////////////
799
+
800
+ } // namespace kernel
801
+ } // namespace conv
802
+ } // namespace cutlass
803
+
804
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
35
+ the appropriate threadblock-scoped epilogue.
36
+
37
+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
38
+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial
39
+ specializations here choose 'device::GemmTransposed' to implement this functionality.
40
+ */
41
+
42
+ #pragma once
43
+
44
+ #include "cutlass/cutlass.h"
45
+
46
+ #include "cutlass/layout/matrix.h"
47
+ #include "cutlass/numeric_types.h"
48
+
49
+ #include "cutlass/epilogue/threadblock/epilogue.h"
50
+ #include "cutlass/epilogue/thread/linear_combination.h"
51
+
52
+ #include "cutlass/gemm/gemm.h"
53
+ #include "cutlass/gemm/kernel/gemm_pipelined.h"
54
+ #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
55
+ #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
56
+ #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
57
+ #include "cutlass/gemm/threadblock/default_mma_core_simt.h"
58
+ #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
59
+ #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
60
+ #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
61
+ #include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
62
+
63
+ #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
64
+
65
+ #include "kernel/b2b_gemm.h"
66
+ #include "kernel/grouped.h"
67
+ #include "threadblock/default_b2b_mma.h"
68
+ #include "threadblock/grouped_threadblock_swizzle.h"
69
+
70
+ ////////////////////////////////////////////////////////////////////////////////
71
+
72
+ namespace cutlass {
73
+ namespace gemm {
74
+ namespace kernel {
75
+
76
+ ////////////////////////////////////////////////////////////////////////////////
77
+
78
+ template <typename T>
79
+ using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle<T>;
80
+
81
+ template <
82
+ /// Element type for A matrix operand
83
+ typename ElementA_,
84
+ /// Layout type for A matrix operand
85
+ typename LayoutA_,
86
+ /// Access granularity of A matrix in units of elements
87
+ int kAlignmentA,
88
+ /// Element type for B matrix operand
89
+ typename ElementB_,
90
+ /// Layout type for B matrix operand
91
+ typename LayoutB_,
92
+ /// Access granularity of B matrix in units of elements
93
+ int kAlignmentB,
94
+ /// Element type for C and D matrix operands
95
+ typename ElementC_,
96
+ /// Layout type for C and D matrix operands
97
+ typename LayoutC_,
98
+ /// Element type for internal accumulation
99
+ typename ElementAccumulator,
100
+ /// Operator class tag
101
+ typename OperatorClass,
102
+ /// Tag indicating architecture to tune for
103
+ typename ArchTag,
104
+ /// Threadblock-level tile size (concept: GemmShape)
105
+ typename ThreadblockShape0,
106
+ /// Threadblock-level tile size (concept: GemmShape)
107
+ typename ThreadblockShape1,
108
+ /// Warp-level tile size (concept: GemmShape)
109
+ typename WarpShape0,
110
+ /// Warp-level tile size (concept: GemmShape)
111
+ typename WarpShape1,
112
+ /// Warp-level tile size (concept: GemmShape)
113
+ typename InstructionShape,
114
+ /// Epilogue output operator
115
+ typename EpilogueOutputOp0,
116
+ /// Epilogue output operator
117
+ typename EpilogueOutputOp1,
118
+ /// Threadblock-level swizzling operator
119
+ typename ThreadblockSwizzle,
120
+ /// Number of stages used in the pipelined mainloop
121
+ int Stages,
122
+ /// Operation performed by GEMM
123
+ typename Operator,
124
+ /// Stage accumulator in shared memory
125
+ bool SmemAccumulator = false,
126
+ /// Whether or not the operation is grouped
127
+ typename Enable = void
128
+ >
129
+ struct DefaultB2bGemm;
130
+
131
+ ////////////////////////////////////////////////////////////////////////////////
132
+
133
+ /// Partial specialization for Ampere Architecture
134
+ template <
135
+ /// Element type for A matrix operand
136
+ typename ElementA,
137
+ /// Layout type for A matrix operand
138
+ typename LayoutA,
139
+ /// Access granularity of A matrix in units of elements
140
+ int kAlignmentA,
141
+ /// Element type for B matrix operand
142
+ typename ElementB,
143
+ /// Layout type for B matrix operand
144
+ typename LayoutB,
145
+ /// Access granularity of A matrix in units of elements
146
+ int kAlignmentB,
147
+ /// Element type for C and D matrix operands
148
+ typename ElementC,
149
+ /// Element type for internal accumulation
150
+ typename ElementAccumulator,
151
+ /// Threadblock-level tile size (concept: GemmShape)
152
+ typename ThreadblockShape0,
153
+ /// Threadblock-level tile size (concept: GemmShape)
154
+ typename ThreadblockShape1,
155
+ /// Warp-level tile size (concept: GemmShape)
156
+ typename WarpShape0,
157
+ /// Warp-level tile size (concept: GemmShape)
158
+ typename WarpShape1,
159
+ /// Warp-level tile size (concept: GemmShape)
160
+ typename InstructionShape,
161
+ /// Epilogue output operator
162
+ typename EpilogueOutputOp0,
163
+ /// Epilogue output operator
164
+ typename EpilogueOutputOp1,
165
+ /// Threadblock-level swizzling operator
166
+ typename ThreadblockSwizzle,
167
+ /// Number of stages used in the pipelined mainloop
168
+ int Stages,
169
+ /// Operation performed by GEMM
170
+ typename Operator>
171
+ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
172
+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
173
+ arch::Sm80, ThreadblockShape0, ThreadblockShape1,
174
+ WarpShape0, WarpShape1, InstructionShape,
175
+ EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
176
+ Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
177
+ /// Define the threadblock-scoped matrix multiply-accumulate
178
+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
179
+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
180
+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
181
+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
182
+ InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
183
+
184
+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
185
+
186
+ /// Define the epilogue
187
+ using Epilogue =
188
+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
189
+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
190
+ EpilogueOutputOp1::kCount>::Epilogue;
191
+
192
+ /// Define the kernel-level GEMM operator.
193
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
194
+ };
195
+
196
+ /// Partial specialization for Ampere Architecture with grouped operation
197
+ template <
198
+ /// Element type for A matrix operand
199
+ typename ElementA,
200
+ /// Layout type for A matrix operand
201
+ typename LayoutA,
202
+ /// Access granularity of A matrix in units of elements
203
+ int kAlignmentA,
204
+ /// Element type for B matrix operand
205
+ typename ElementB,
206
+ /// Layout type for B matrix operand
207
+ typename LayoutB,
208
+ /// Access granularity of A matrix in units of elements
209
+ int kAlignmentB,
210
+ /// Element type for C and D matrix operands
211
+ typename ElementC,
212
+ /// Element type for internal accumulation
213
+ typename ElementAccumulator,
214
+ /// Threadblock-level tile size (concept: GemmShape)
215
+ typename ThreadblockShape0,
216
+ /// Threadblock-level tile size (concept: GemmShape)
217
+ typename ThreadblockShape1,
218
+ /// Warp-level tile size (concept: GemmShape)
219
+ typename WarpShape0,
220
+ /// Warp-level tile size (concept: GemmShape)
221
+ typename WarpShape1,
222
+ /// Warp-level tile size (concept: GemmShape)
223
+ typename InstructionShape,
224
+ /// Epilogue output operator
225
+ typename EpilogueOutputOp0,
226
+ /// Epilogue output operator
227
+ typename EpilogueOutputOp1,
228
+ /// Threadblock-level swizzling operator
229
+ typename ThreadblockSwizzle,
230
+ /// Number of stages used in the pipelined mainloop
231
+ int Stages,
232
+ /// Operation performed by GEMM
233
+ typename Operator>
234
+ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
235
+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
236
+ arch::Sm80, ThreadblockShape0, ThreadblockShape1,
237
+ WarpShape0, WarpShape1, InstructionShape,
238
+ EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
239
+ Operator, false, typename platform::enable_if<IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
240
+ /// Define the threadblock-scoped matrix multiply-accumulate
241
+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
242
+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
243
+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
244
+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
245
+ InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
246
+
247
+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
248
+
249
+ /// Define the epilogue
250
+ using Epilogue =
251
+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
252
+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
253
+ EpilogueOutputOp1::kCount>::Epilogue;
254
+
255
+ /// Define the kernel-level GEMM operator.
256
+ using UnderlyingB2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
257
+
258
+ using B2bGemmKernel = kernel::GroupedKernel<UnderlyingB2bGemmKernel>;
259
+ };
260
+
261
+
262
+ ////////////////////////////////////////////////////////////////////////////////
263
+
264
+ /// Partial specialization for Turing Architecture
265
+ template <
266
+ /// Element type for A matrix operand
267
+ typename ElementA,
268
+ /// Layout type for A matrix operand
269
+ typename LayoutA,
270
+ /// Access granularity of A matrix in units of elements
271
+ int kAlignmentA,
272
+ /// Element type for B matrix operand
273
+ typename ElementB,
274
+ /// Layout type for B matrix operand
275
+ typename LayoutB,
276
+ /// Access granularity of B matrix in units of elements
277
+ int kAlignmentB,
278
+ /// Element type for C and D matrix operands
279
+ typename ElementC,
280
+ /// Element type for internal accumulation
281
+ typename ElementAccumulator,
282
+ /// Threadblock-level tile size (concept: GemmShape)
283
+ typename ThreadblockShape0,
284
+ /// Threadblock-level tile size (concept: GemmShape)
285
+ typename ThreadblockShape1,
286
+ /// Warp-level tile size (concept: GemmShape)
287
+ typename WarpShape0,
288
+ /// Warp-level tile size (concept: GemmShape)
289
+ typename WarpShape1,
290
+ /// Warp-level tile size (concept: GemmShape)
291
+ typename InstructionShape,
292
+ /// Epilogue output operator
293
+ typename EpilogueOutputOp0,
294
+ /// Epilogue output operator
295
+ typename EpilogueOutputOp1,
296
+ /// Threadblock-level swizzling operator
297
+ typename ThreadblockSwizzle,
298
+ /// Operation performed by GEMM
299
+ typename Operator
300
+ >
301
+ struct DefaultB2bGemm<
302
+ ElementA, LayoutA, kAlignmentA,
303
+ ElementB, LayoutB, kAlignmentB,
304
+ ElementC, layout::RowMajor,
305
+ ElementAccumulator,
306
+ arch::OpClassTensorOp,
307
+ arch::Sm75,
308
+ ThreadblockShape0,
309
+ ThreadblockShape1,
310
+ WarpShape0,
311
+ WarpShape1,
312
+ InstructionShape,
313
+ EpilogueOutputOp0,
314
+ EpilogueOutputOp1,
315
+ ThreadblockSwizzle,
316
+ 2,
317
+ Operator,
318
+ false,
319
+ typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type
320
+ > {
321
+
322
+ /// Define the threadblock-scoped matrix multiply-accumulate
323
+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
324
+ ElementA,
325
+ LayoutA,
326
+ kAlignmentA,
327
+ ElementB,
328
+ LayoutB,
329
+ kAlignmentB,
330
+ ElementAccumulator,
331
+ layout::RowMajor,
332
+ arch::OpClassTensorOp,
333
+ arch::Sm75,
334
+ ThreadblockShape0,
335
+ ThreadblockShape1,
336
+ WarpShape0,
337
+ WarpShape1,
338
+ InstructionShape,
339
+ 2,
340
+ Operator,
341
+ EpilogueOutputOp0
342
+ >::ThreadblockB2bMma;
343
+
344
+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
345
+
346
+ /// Define the epilogue
347
+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
348
+ ThreadblockShape1,
349
+ typename B2bMma::Operator1,
350
+ kPartitionsK1,
351
+ EpilogueOutputOp1,
352
+ EpilogueOutputOp1::kCount
353
+ >::Epilogue;
354
+
355
+ /// Define the kernel-level GEMM operator.
356
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
357
+ };
358
+
359
+
360
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout.
/// Selected when A is ColumnMajorInterleaved<K>, B is RowMajorInterleaved<K>,
/// C is ColumnMajorInterleaved<K>, the accumulator is int32_t, and the swizzle
/// is not a grouped swizzle (enable_if guard on the last argument).
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator for the first GEMM
    typename EpilogueOutputOp0,
    /// Epilogue output operator for the second GEMM
    typename EpilogueOutputOp1,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Number of Interleaved k
    int InterleavedK,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultB2bGemm<
    ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
    ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
    ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
    arch::OpClassTensorOp, arch::Sm80,
    ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
    InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
    ThreadblockSwizzle, Stages,
    Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

  // Integer tensor-core GEMMs accumulate in int32_t.
  using ElementAccumulator = int32_t;

  /// Define the threadblock-scoped matrix multiply-accumulate
  // Trailing `true` is DefaultB2bMma's interleaved-layout flag (matches the
  // interleaved A/B/C layouts of this specialization).
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
      InstructionShape, Stages, Operator, EpilogueOutputOp0,
      true>::ThreadblockB2bMma;

  // K-dimension partitions of the second GEMM's threadblock tile across warps.
  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue (interleaved variant; vector width chosen so each
  /// access is 64 bits of ElementC).
  using Epilogue = typename cutlass::epilogue::threadblock::
      DefaultInterleavedEpilogueTensorOp<
          ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
428
+
429
+ ////////////////////////////////////////////////////////////////////////////////
430
+
431
+
432
/// Partial specialization for Turing Integer Tensor Core Interleaved layout.
/// Same shape as the Ampere interleaved specialization above, but pinned to
/// arch::Sm75 with a fixed 2-stage pipelined mainloop.
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator for the first GEMM
    typename EpilogueOutputOp0,
    /// Epilogue output operator for the second GEMM
    typename EpilogueOutputOp1,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of Interleaved k
    int InterleavedK,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
                      kAlignmentA, ElementB,
                      layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
                      ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
                      int32_t, arch::OpClassTensorOp, arch::Sm75,
                      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
                      InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
                      ThreadblockSwizzle, 2, Operator, false,
                      typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

  // Integer tensor-core GEMMs accumulate in int32_t.
  using ElementAccumulator = int32_t;

  /// Define the threadblock-scoped matrix multiply-accumulate
  // `2` = two pipeline stages (Sm75); trailing `true` is the interleaved flag.
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
      arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
      WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;

  // K-dimension partitions of the second GEMM's threadblock tile across warps.
  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue for the 2nd Gemm (interleaved variant; vector width
  /// chosen so each access is 64 bits of ElementC).
  using Epilogue = typename cutlass::epilogue::threadblock::
      DefaultInterleavedEpilogueTensorOp<
          ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
496
+
497
+ ////////////////////////////////////////////////////////////////////////////////
498
+
499
+ ////////////////////////////////////////////////////////////////////////////////
500
+
501
+ } // namespace kernel
502
+ } // namespace gemm
503
+ } // namespace cutlass
build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief
34
+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
35
+ the appropriate threadblock-scoped epilogue.
36
+
37
+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
38
+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial
39
+ specializations here choose 'device::GemmTransposed' to implement this functionality.
40
+ */
41
+
42
+ #pragma once
43
+
44
+ #include "cutlass/cutlass.h"
45
+
46
+ #include "cutlass/layout/matrix.h"
47
+ #include "cutlass/numeric_types.h"
48
+
49
+ #include "cutlass/epilogue/threadblock/epilogue.h"
50
+ #include "cutlass/epilogue/thread/linear_combination.h"
51
+
52
+ #include "cutlass/gemm/gemm.h"
53
+ #include "cutlass/gemm/kernel/gemm_pipelined.h"
54
+ #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
55
+ #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
56
+ #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
57
+ #include "cutlass/gemm/threadblock/default_mma_core_simt.h"
58
+ #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
59
+ #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
60
+ #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
61
+ #include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
62
+
63
+ #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
64
+ #include "cutlass/transform/threadblock/vector_iterator.h"
65
+ #include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
66
+
67
+ #include "kernel/b2b_gemm.h"
68
+ #include "threadblock/default_b2b_mma.h"
69
+ #include "threadblock/default_b2b_mma_smem_accumulator.h"
70
+
71
+ ////////////////////////////////////////////////////////////////////////////////
72
+
73
+ namespace cutlass {
74
+ namespace gemm {
75
+ namespace kernel {
76
+
77
+ ////////////////////////////////////////////////////////////////////////////////
78
+
79
/// Partial specialization for Ampere Architecture.
/// Selected when the boolean SmemAccumulator flag (last specialization
/// argument, `true`) is set — presumably staging the first GEMM's accumulator
/// in shared memory, per this file's name; confirm against DefaultB2bMma.
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator for the first GEMM
    typename EpilogueOutputOp0,
    /// Epilogue output operator for the second GEMM
    typename EpilogueOutputOp1,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
                      layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
                      arch::Sm80, ThreadblockShape0, ThreadblockShape1,
                      WarpShape0, WarpShape1, InstructionShape,
                      EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
                      Operator, true> {
  /// Define the threadblock-scoped matrix multiply-accumulate
  // Trailing flags: `false` = not interleaved, `true` = smem-accumulator path.
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
      InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;

  // K-dimension partitions of the second GEMM's threadblock tile across warps.
  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
          EpilogueOutputOp1::kCount>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
141
+
142
+ ////////////////////////////////////////////////////////////////////////////////
143
+
144
/// Partial specialization for Turing Architecture (smem-accumulator variant:
/// last specialization argument is `true`). Fixed 2-stage mainloop on Sm75.
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator for the first GEMM
    typename EpilogueOutputOp0,
    /// Epilogue output operator for the second GEMM
    typename EpilogueOutputOp1,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Operation performed by GEMM
    typename Operator
>
struct DefaultB2bGemm<
    ElementA, LayoutA, kAlignmentA,
    ElementB, LayoutB, kAlignmentB,
    ElementC, layout::RowMajor,
    ElementAccumulator,
    arch::OpClassTensorOp,
    arch::Sm75,
    ThreadblockShape0,
    ThreadblockShape1,
    WarpShape0,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    ThreadblockSwizzle,
    2,          // two pipeline stages (Sm75)
    Operator,
    true        // SmemAccumulator
> {

  /// Define the threadblock-scoped matrix multiply-accumulate
  // Trailing flags: `false` = not interleaved, `true` = smem-accumulator path.
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA,
      LayoutA,
      kAlignmentA,
      ElementB,
      LayoutB,
      kAlignmentB,
      ElementAccumulator,
      layout::RowMajor,
      arch::OpClassTensorOp,
      arch::Sm75,
      ThreadblockShape0,
      ThreadblockShape1,
      WarpShape0,
      WarpShape1,
      InstructionShape,
      2,
      Operator,
      EpilogueOutputOp0,
      false,
      true
  >::ThreadblockB2bMma;

  // K-dimension partitions of the second GEMM's threadblock tile across warps.
  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
      ThreadblockShape1,
      typename B2bMma::Operator1,
      kPartitionsK1,
      EpilogueOutputOp1,
      EpilogueOutputOp1::kCount
  >::Epilogue;

  /// Define the kernel-level GEMM operator.
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
239
+
240
+
241
/// Partial specialization for Ampere Integer Matrix Multiply Interleaved
/// layout, smem-accumulator variant (last specialization argument `true`).
/// A is ColumnMajorInterleaved<K>, B is RowMajorInterleaved<K>,
/// C is ColumnMajorInterleaved<K>; accumulation is int32_t.
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator for the first GEMM
    typename EpilogueOutputOp0,
    /// Epilogue output operator for the second GEMM
    typename EpilogueOutputOp1,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Number of Interleaved k
    int InterleavedK,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultB2bGemm<
    ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
    ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
    ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
    arch::OpClassTensorOp, arch::Sm80,
    ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
    InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
    ThreadblockSwizzle, Stages,
    Operator, true> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

  // Integer tensor-core GEMMs accumulate in int32_t.
  using ElementAccumulator = int32_t;

  /// Define the threadblock-scoped matrix multiply-accumulate
  // Trailing flags: `true` = interleaved layouts, `true` = smem-accumulator path.
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
      InstructionShape, Stages, Operator, EpilogueOutputOp0,
      true, true>::ThreadblockB2bMma;

  // K-dimension partitions of the second GEMM's threadblock tile across warps.
  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue (interleaved variant; vector width chosen so each
  /// access is 64 bits of ElementC).
  using Epilogue = typename cutlass::epilogue::threadblock::
      DefaultInterleavedEpilogueTensorOp<
          ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
309
+
310
+ ////////////////////////////////////////////////////////////////////////////////
311
+
312
+
313
/// Partial specialization for Turing Integer Tensor Core Interleaved layout,
/// smem-accumulator variant (last specialization argument `true`). Pinned to
/// arch::Sm75 with a fixed 2-stage pipelined mainloop.
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for C and D matrix operands
    typename ElementC,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape0,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape1,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape0,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape1,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Epilogue output operator for the first GEMM
    typename EpilogueOutputOp0,
    /// Epilogue output operator for the second GEMM
    typename EpilogueOutputOp1,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle,
    /// Number of Interleaved k
    int InterleavedK,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
                      kAlignmentA, ElementB,
                      layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
                      ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
                      int32_t, arch::OpClassTensorOp, arch::Sm75,
                      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
                      InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
                      ThreadblockSwizzle, 2, Operator, true> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

  // Integer tensor-core GEMMs accumulate in int32_t.
  using ElementAccumulator = int32_t;

  /// Define the threadblock-scoped matrix multiply-accumulate
  // `2` = two pipeline stages (Sm75); trailing flags: `true` = interleaved
  // layouts, `true` = smem-accumulator path.
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
      InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;

  // K-dimension partitions of the second GEMM's threadblock tile across warps.
  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue for the 2nd Gemm (interleaved variant; vector width
  /// chosen so each access is 64 bits of ElementC).
  using Epilogue = typename cutlass::epilogue::threadblock::
      DefaultInterleavedEpilogueTensorOp<
          ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
          64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
377
+
378
+ ////////////////////////////////////////////////////////////////////////////////
379
+
380
+ ////////////////////////////////////////////////////////////////////////////////
381
+
382
+ } // namespace kernel
383
+ } // namespace gemm
384
+ } // namespace cutlass